diff --git a/candle-transformers/src/models/llava/mod.rs b/candle-transformers/src/models/llava/mod.rs
index caa8737af6..1ed3b50c63 100644
--- a/candle-transformers/src/models/llava/mod.rs
+++ b/candle-transformers/src/models/llava/mod.rs
@@ -279,7 +279,7 @@ impl LLaVA {
                         (),
                     ))?
                 } else {
-                    todo!("not implemented in original python LLaVA yet")
+                    bail!("not implemented in original python LLaVA yet")
                 };
                 let new_image_feature = if mm_patch_merge_type.contains("unpad") {
                     let new_image_feature = new_image_feature
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
new file mode 100644
index 0000000000..5888fc81b6
--- /dev/null
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -0,0 +1,37 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::{linear, Linear, VarBuilder};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+    pub projector_hidden_act: candle_nn::Activation,
+    pub text_config_hidden_size: usize,
+    pub vision_config_hidden_size: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct MultiModalProjector {
+    linear_1: Linear,
+    act: candle_nn::Activation,
+    linear_2: Linear,
+}
+
+impl MultiModalProjector {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let (hidden_v, hidden_t) = (cfg.vision_config_hidden_size, cfg.text_config_hidden_size);
+        let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
+        let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
+        Ok(Self {
+            linear_1,
+            act: cfg.projector_hidden_act,
+            linear_2,
+        })
+    }
+}
+
+impl Module for MultiModalProjector {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.linear_1)?
+            .apply(&self.act)?
+            .apply(&self.linear_2)
+    }
+}
diff --git a/candle-transformers/src/models/pixtral/mod.rs b/candle-transformers/src/models/pixtral/mod.rs
index 2f3a4ce300..b9ef83ad25 100644
--- a/candle-transformers/src/models/pixtral/mod.rs
+++ b/candle-transformers/src/models/pixtral/mod.rs
@@ -1 +1,2 @@
+pub mod llava;
 pub mod vision_model;
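For context, here is a minimal usage sketch of the new `MultiModalProjector`: it projects patch embeddings from the vision encoder's hidden size into the text model's hidden size via two linear layers with an activation in between. The sketch assumes the `candle_core`/`candle_transformers` crate paths of an out-of-tree caller, `Activation::Gelu` as the projector activation, and example hidden sizes (1024 vision, 5120 text); real values come from the model's config, and the `VarBuilder` would normally be loaded from safetensors rather than zero-initialized.

```rust
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::pixtral::llava::{Config, MultiModalProjector};

fn main() -> Result<()> {
    // Example sizes only; in practice these are read from the model's config.json.
    let cfg = Config {
        projector_hidden_act: candle_nn::Activation::Gelu,
        text_config_hidden_size: 5120,
        vision_config_hidden_size: 1024,
    };

    // Zero-initialized weights for illustration; a real caller would load a
    // VarBuilder from safetensors rooted at the projector's weight prefix.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
    let projector = MultiModalProjector::new(&cfg, vb)?;

    // Fake batch of vision-encoder patch embeddings:
    // (batch, num_patches, vision_hidden_size).
    let image_features = Tensor::zeros((1, 16, 1024), DType::F32, &Device::Cpu)?;

    // Forward pass maps the last dimension 1024 -> 5120.
    let projected = projector.forward(&image_features)?;
    assert_eq!(projected.dims(), &[1, 16, 5120]);
    Ok(())
}
```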