
Commit 9c99c11

Get everything in place.
LaurentMazare committed Sep 28, 2024
1 parent fbf06c4 commit 9c99c11
Showing 2 changed files with 6 additions and 37 deletions.
candle-examples/examples/siglip/main.rs (2 changes: 1 addition & 1 deletion)
@@ -140,7 +140,7 @@ pub fn tokenize_sequences(
         let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
         tokens.push(encoding.get_ids().to_vec());
     }
-    let max_len = tokens.iter().map(|v| v.len()).max().unwrap_or(0);
+    let max_len = config.text_config.max_position_embeddings;
     // Pad the sequences to have the same length
     for token_vec in tokens.iter_mut() {
         let len_diff = max_len - token_vec.len();
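With this change, the padding loop right below the edited line pads every tokenized sequence up to the model's full context length (config.text_config.max_position_embeddings) instead of the longest sequence in the batch, so all text inputs share one fixed shape. A minimal sketch of that behaviour, assuming a hypothetical pad_to_context_len helper and pad_id value (the real example keeps this logic inline in tokenize_sequences):

// Illustrative sketch only: the names `pad_to_context_len` and `pad_id` are
// assumptions, not the example's API. `context_len` stands for
// config.text_config.max_position_embeddings.
fn pad_to_context_len(tokens: &mut Vec<Vec<u32>>, context_len: usize, pad_id: u32) {
    for token_vec in tokens.iter_mut() {
        // Mirrors `max_len - token_vec.len()` in the diff; this underflows
        // (panics in debug builds) if a tokenized sequence is longer than the
        // context length, so inputs are assumed to fit within `context_len`.
        let len_diff = context_len - token_vec.len();
        if len_diff > 0 {
            token_vec.extend(std::iter::repeat(pad_id).take(len_diff));
        }
    }
}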
candle-transformers/src/models/siglip.rs (41 changes: 5 additions & 36 deletions)
@@ -530,37 +530,12 @@ impl TextTransformer {
             head,
         })
     }
-
-    // TODO: rewrite to newer version
-    fn build_causal_attention_mask(
-        bsz: usize,
-        seq_len: usize,
-        mask_after: usize,
-        device: &candle::Device,
-    ) -> Result<Tensor> {
-        let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| {
-                (0..seq_len).map(move |j| {
-                    if j > i || j > mask_after {
-                        f32::MIN
-                    } else {
-                        0.
-                    }
-                })
-            })
-            .collect();
-        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
-        mask.broadcast_as((bsz, 1, seq_len, seq_len))
-    }
-
-    pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
-        let (bsz, seq_len) = input_ids.dims2()?;
+}
+impl Module for TextTransformer {
+    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_bsz, seq_len) = input_ids.dims2()?;
         let input_ids = self.embeddings.forward(input_ids)?;
-        let causal_attention_mask =
-            Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
-        let input_ids = self
-            .encoder
-            .forward(&input_ids, Some(&causal_attention_mask))?;
+        let input_ids = self.encoder.forward(&input_ids, None)?;
         let last_hidden_state = self.final_layer_norm.forward(&input_ids)?;
         last_hidden_state
             .i((.., seq_len - 1, ..))?
@@ -569,12 +544,6 @@ impl TextTransformer {
     }
 }
-
-impl Module for TextTransformer {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        self.forward_with_mask(xs, usize::MAX)
-    }
-}
 
 #[derive(Debug, Clone)]
 pub struct TextModel {
     pub text_model: TextTransformer,
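The net effect of this file's change: build_causal_attention_mask and forward_with_mask are removed, TextTransformer now implements Module directly, and the encoder runs with no attention mask (None) while pooling still reads the last token position. A hedged usage sketch, assuming the in-tree crate alias (candle for candle-core), that TextTransformer is reachable at candle_transformers::models::siglip, and that the input is a (batch, seq_len) tensor of token ids already padded to max_position_embeddings as in the main.rs change above:

use candle::{Device, Module, Result, Tensor};
use candle_transformers::models::siglip::TextTransformer;

// Illustrative only: the function name and shapes are assumptions, not repo API.
// After this commit the text tower is driven through the plain `Module` trait,
// with no mask argument to thread through.
fn embed_texts(text_model: &TextTransformer, ids: &[u32], seq_len: usize, dev: &Device) -> Result<Tensor> {
    let batch = ids.len() / seq_len;
    let input_ids = Tensor::from_slice(ids, (batch, seq_len), dev)?;
    // Internally this runs the encoder with `None` as the attention mask and
    // pools the hidden state at position `seq_len - 1`, per the diff above.
    text_model.forward(&input_ids)
}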
