
Commit

Complete the forward pass.
LaurentMazare committed Sep 28, 2024
1 parent 1897c32 commit 20c2e79
Showing 2 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion candle-transformers/src/models/clip/text_model.rs
@@ -298,7 +298,7 @@ impl ClipTextTransformer
})
}

- // TODO: rewrrite to newer version
+ // TODO: rewrite to newer version
fn build_causal_attention_mask(
bsz: usize,
seq_len: usize,
71 changes: 57 additions & 14 deletions candle-transformers/src/models/siglip.rs
@@ -1,5 +1,4 @@
- #![allow(unused)]
- use candle::{Result, Tensor, D};
+ use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
@@ -90,9 +89,6 @@ pub struct Config {
pub vision_config: VisionConfig,
}

- #[derive(Debug, Clone)]
- struct MultiheadAttentionPoolingHead {}
-
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
@@ -163,7 +159,6 @@ impl Mlp {
fn new<C: TransformerConfig>(cfg: &C, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size();
let intermediate_size = cfg.intermediate_size();
- let hidden_act = cfg.hidden_act();
let fc1 = candle_nn::linear(hidden_size, intermediate_size, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(intermediate_size, hidden_size, vb.pp("fc2"))?;
Ok(Self {
@@ -174,7 +169,7 @@ impl Mlp {
}
}

- impl candle::Module for Mlp {
+ impl Module for Mlp {
fn forward(&self, xs: &candle::Tensor) -> Result<candle::Tensor> {
xs.apply(&self.fc1)?
.apply(&self.activation_fn)?
@@ -275,9 +270,9 @@ impl VisionEmbeddings {
}
}

- impl candle::Module for VisionEmbeddings {
+ impl Module for VisionEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
- let (_, _, height, width) = xs.dims4()?;
+ let (_batch, _channels, _height, _width) = xs.dims4()?;
let embeddings = xs
.apply(&self.patch_embedding)?
.flatten_from(2)?
@@ -309,7 +304,7 @@ impl VisionTransformer {
}
}

- impl candle::Module for VisionTransformer {
+ impl Module for VisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let xs = self.encoder.forward(&xs, None)?;
@@ -323,13 +318,13 @@ pub struct VisionModel {
}

impl VisionModel {
- fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionTransformer::new(cfg, vb.pp("vision_model"))?;
Ok(Self { vision_model })
}
}

- impl candle::Module for VisionModel {
+ impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.vision_model)
}
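
The dims4 destructuring in VisionEmbeddings::forward above fixes the input layout the now-public vision tower expects. A minimal sketch (not part of the commit), assuming the siglip module is exported from candle_transformers::models:

use candle::{Module, Result, Tensor};
use candle_transformers::models::siglip::VisionModel;

// Hedged sketch, not part of the commit: the vision tower consumes image
// batches laid out as (batch, channels, height, width), matching the dims4
// call above; height and width must equal the checkpoint's image size.
fn encode_image(model: &VisionModel, pixel_values: &Tensor) -> Result<Tensor> {
    // pixel_values: f32, e.g. (1, 3, 224, 224) for a 224x224 checkpoint.
    model.forward(pixel_values)
}
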
@@ -361,7 +356,7 @@ impl TextEmbeddings {
}
}

- impl candle::Module for TextEmbeddings {
+ impl Module for TextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = self.token_embedding.forward(input_ids)?;
@@ -396,6 +391,48 @@ impl TextTransformer {
head,
})
}

+ // TODO: rewrite to newer version
+ fn build_causal_attention_mask(
+ bsz: usize,
+ seq_len: usize,
+ mask_after: usize,
+ device: &candle::Device,
+ ) -> Result<Tensor> {
+ let mask: Vec<_> = (0..seq_len)
+ .flat_map(|i| {
+ (0..seq_len).map(move |j| {
+ if j > i || j > mask_after {
+ f32::MIN
+ } else {
+ 0.
+ }
+ })
+ })
+ .collect();
+ let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
+ mask.broadcast_as((bsz, 1, seq_len, seq_len))
+ }
+
+ pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
+ let (bsz, seq_len) = input_ids.dims2()?;
+ let input_ids = self.embeddings.forward(input_ids)?;
+ let causal_attention_mask =
+ Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
+ let input_ids = self
+ .encoder
+ .forward(&input_ids, Some(&causal_attention_mask))?;
+ let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
+ last_hidden_state
+ .i((.., seq_len - 1, ..))?
+ .apply(&self.head)
+ }
}

+ impl Module for TextTransformer {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ self.forward_with_mask(xs, usize::MAX)
+ }
+ }

#[derive(Debug, Clone)]
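
In the mask built above, entry (i, j) is f32::MIN whenever key position j lies after query position i or after mask_after, and 0 otherwise; added to the attention logits before the softmax, those keys receive effectively zero weight. A standalone illustration (not part of the commit) of the layout for seq_len = 4 and mask_after = 2:

// Illustration only: reproduces the masking rule used above for a tiny case.
// Row i is the query position, column j the key position.
fn main() {
    let (seq_len, mask_after) = (4usize, 2usize);
    for i in 0..seq_len {
        let row: Vec<f32> = (0..seq_len)
            .map(|j| if j > i || j > mask_after { f32::MIN } else { 0. })
            .collect();
        println!("{row:?}");
    }
    // With MIN standing in for f32::MIN, this prints:
    // [0.0, MIN, MIN, MIN]
    // [0.0, 0.0, MIN, MIN]
    // [0.0, 0.0, 0.0, MIN]
    // [0.0, 0.0, 0.0, MIN]   <- j = 3 stays masked because 3 > mask_after
}
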
@@ -404,8 +441,14 @@ pub struct TextModel {
}

impl TextModel {
- fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &TextConfig, vb: VarBuilder) -> Result<Self> {
let text_model = TextTransformer::new(cfg, vb.pp("text_model"))?;
Ok(Self { text_model })
}
}

+ impl Module for TextModel {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ xs.apply(&self.text_model)
+ }
+ }
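
With TextModel::new public and the Module impl in place, the text tower runs end to end; the Module path uses mask_after = usize::MAX, i.e. a plain causal forward over the whole sequence, pooling the last position through the head projection. A minimal shape-check sketch (not part of the commit), assuming the siglip module is exported from candle_transformers::models and that a TextConfig describing the checkpoint is available; real use would load safetensors into the VarBuilder and pass tokenized ids instead of zeros:

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::siglip::{TextConfig, TextModel};

fn encode_dummy_text(cfg: &TextConfig) -> Result<Tensor> {
    let device = Device::Cpu;
    // Zero weights are enough to exercise the forward pass and check shapes.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let model = TextModel::new(cfg, vb)?;
    // One sequence of 64 token ids (u32 for the embedding lookup); 64 is
    // assumed to be within the model's max position embeddings.
    let input_ids = Tensor::zeros((1, 64), DType::U32, &device)?;
    // Internally this is forward_with_mask(&input_ids, usize::MAX): the last
    // position's hidden state is taken and projected by the head layer.
    model.forward(&input_ids)
}
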
