diff --git a/candle-examples/examples/mobileclip/main.rs b/candle-examples/examples/mobileclip/main.rs
index d505fc7c48..d9615c43b8 100644
--- a/candle-examples/examples/mobileclip/main.rs
+++ b/candle-examples/examples/mobileclip/main.rs
@@ -60,7 +60,6 @@ fn load_images<T: AsRef<std::path::Path>>(
     image_size: usize,
 ) -> anyhow::Result<Tensor> {
     let mut images = vec![];
-
     for path in paths {
         let tensor = candle_examples::imagenet::load_image_with_std_mean(
             path,
@@ -70,9 +69,7 @@ fn load_images<T: AsRef<std::path::Path>>(
         )?;
         images.push(tensor);
     }
-
     let images = Tensor::stack(&images, 0)?;
-
     Ok(images)
 }
 
@@ -80,24 +77,17 @@ pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
     let model_name = args.which.model_name();
-
     let api = hf_hub::api::sync::Api::new()?;
     let api = api.model(model_name);
-
     let model_file = if args.use_pth {
         api.get("open_clip_pytorch_model.bin")?
     } else {
         api.get("open_clip_model.safetensors")?
     };
-
     let tokenizer = api.get("tokenizer.json")?;
-
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
-
     let config = &args.which.config();
-
     let device = candle_examples::device(args.cpu)?;
-
     let vec_imgs = match args.images {
         Some(imgs) => imgs,
         None => vec![
@@ -105,9 +95,7 @@ pub fn main() -> anyhow::Result<()> {
             "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
         ],
     };
-
     let images = load_images(&vec_imgs, config.image_size)?.to_device(&device)?;
-
     let vb = if args.use_pth {
         VarBuilder::from_pth(&model_file, DType::F32, &device)?
     } else {
@@ -115,22 +103,15 @@ pub fn main() -> anyhow::Result<()> {
     };
 
     let model = mobileclip::MobileClipModel::new(vb, config)?;
-
     let (input_ids, vec_seq) = tokenize_sequences(args.sequences, &tokenizer, &device)?;
-
     let (_logits_per_text, logits_per_image) = model.forward(&images, &input_ids)?;
-
     let softmax_image = softmax(&logits_per_image, 1)?;
-
     let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
-
     println!("softmax_image_vec: {:?}", softmax_image_vec);
-
     let probability_vec = softmax_image_vec
         .iter()
         .map(|v| v * 100.0)
         .collect::<Vec<f32>>();
-
     let probability_per_image = probability_vec.len() / vec_imgs.len();
 
     for (i, img) in vec_imgs.iter().enumerate() {
@@ -171,7 +152,6 @@ pub fn tokenize_sequences(
     };
 
     let mut tokens = vec![];
-
     for seq in vec_seq.clone() {
         let encoding = tokenizer.encode(seq, true).map_err(E::msg)?;
         tokens.push(encoding.get_ids().to_vec());
@@ -185,8 +165,6 @@ pub fn tokenize_sequences(
             token_vec.extend(vec![pad_id; len_diff]);
         }
     }
-
     let input_ids = Tensor::new(tokens, device)?;
-
     Ok((input_ids, vec_seq))
 }
diff --git a/candle-transformers/src/models/fastvit.rs b/candle-transformers/src/models/fastvit.rs
index 8199874276..b7bdaf888a 100644
--- a/candle-transformers/src/models/fastvit.rs
+++ b/candle-transformers/src/models/fastvit.rs
@@ -495,7 +495,6 @@ fn fastvit_model(cfg: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Resul
             .apply(&stage3)?
             .apply(&stage4)?
             .apply(&final_conv)?;
-
         match &cls {
             None => Ok(xs),
             Some(cls) => xs.mean(D::Minus2)?.mean(D::Minus1)?.apply(cls),
diff --git a/candle-transformers/src/models/mobileclip.rs b/candle-transformers/src/models/mobileclip.rs
index 4953d835b5..45a5dbad9f 100644
--- a/candle-transformers/src/models/mobileclip.rs
+++ b/candle-transformers/src/models/mobileclip.rs
@@ -22,7 +22,6 @@ impl MobileClipConfig {
     pub fn s1() -> Self {
         let text_config = text_model::Config::vit_base_patch32();
         let vision_config = fastvit::Config::mci1();
-
         Self {
             text_config,
             vision_config,
@@ -32,7 +31,6 @@ impl MobileClipConfig {
     pub fn s2() -> Self {
         let text_config = text_model::Config::vit_base_patch32();
         let vision_config = fastvit::Config::mci2();
-
         Self {
             text_config,
             vision_config,
@@ -45,12 +43,10 @@ impl MobileClipModel {
     pub fn new(vs: VarBuilder, c: &MobileClipConfig) -> Result<Self> {
         let vision_model = fastvit::fastvit(&c.vision_config, 512, vs.pp("visual.trunk"))?;
         let text_model = text_model::OpenClipTextTransformer::new(vs.pp("text"), &c.text_config)?;
-
         let text_projection = vs.get(
             (c.text_config.embed_dim, c.text_config.projection_dim),
             "text.text_projection",
         )?;
-
         let logit_scale = vs.get(&[], "logit_scale")?;
         Ok(Self {
             text_model,