Skip to content

Commit

Permalink
Add the llava multimodal adapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 30, 2024
1 parent d8d2ada commit 3a2df9b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion candle-transformers/src/models/llava/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ impl LLaVA {
(),
))?
} else {
todo!("not implemented in original python LLaVA yet")
bail!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature
Expand Down
37 changes: 37 additions & 0 deletions candle-transformers/src/models/pixtral/llava.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

#[derive(Debug, Clone)]
pub struct Config {
pub projector_hidden_act: candle_nn::Activation,
pub text_config_hidden_size: usize,
pub vision_config_hidden_size: usize,
}

#[derive(Debug, Clone)]
pub struct MultiModalProjector {
linear_1: Linear,
act: candle_nn::Activation,
linear_2: Linear,
}

impl MultiModalProjector {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let (hidden_v, hidden_t) = (cfg.vision_config_hidden_size, cfg.text_config_hidden_size);
let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
Ok(Self {
linear_1,
act: cfg.projector_hidden_act,
linear_2,
})
}
}

impl Module for MultiModalProjector {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.linear_1)?
.apply(&self.act)?
.apply(&self.linear_2)
}
}
1 change: 1 addition & 0 deletions candle-transformers/src/models/pixtral/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod llava;
pub mod vision_model;

0 comments on commit 3a2df9b

Please sign in to comment.