Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 28, 2024
1 parent f26aa5e commit faa20ef
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions candle-examples/examples/siglip/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub fn main() -> anyhow::Result<()> {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
let model = siglip::Model::new(&config, vb)?;
let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
let (input_ids, vec_seq) = tokenize_sequences(&config, args.sequences, &tokenizer, &device)?;
let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
let softmax_image = softmax(&logits_per_image, 1)?;
let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
Expand All @@ -111,11 +111,7 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
let tokenizer = match tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"openai/clip-vit-base-patch32".to_string(),
hf_hub::RepoType::Model,
"refs/pr/15".to_string(),
));
let api = api.model("google/siglip-base-patch16-224".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
Expand All @@ -125,14 +121,12 @@ pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
}

pub fn tokenize_sequences(
config: &siglip::Config,
sequences: Option<Vec<String>>,
tokenizer: &Tokenizer,
device: &Device,
) -> anyhow::Result<(Tensor, Vec<String>)> {
let pad_id = *tokenizer
.get_vocab(true)
.get("<|endoftext|>")
.ok_or(E::msg("No pad token"))?;
let pad_id = config.text_config.pad_token_id;
let vec_seq = match sequences {
Some(seq) => seq,
None => vec![
Expand Down

0 comments on commit faa20ef

Please sign in to comment.