From 75d325ac5f21326c797936fdd3e09b9761f7ea5b Mon Sep 17 00:00:00 2001
From: Laurent
Date: Mon, 30 Sep 2024 21:05:41 +0200
Subject: [PATCH 1/2] Pixtral polishing.

---
 candle-examples/examples/pixtral/main.rs | 15 +++--------
 .../src/models/pixtral/llava.rs          | 26 +++++++++++++++++++
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs
index 8e48b60b9a..79f438686f 100644
--- a/candle-examples/examples/pixtral/main.rs
+++ b/candle-examples/examples/pixtral/main.rs
@@ -73,22 +73,18 @@ impl TextGeneration {
         let img_break = get_token("[IMG_BREAK]")?;
         let img_end = get_token("[IMG_END]")?;
         let start_gen = std::time::Instant::now();
-        let mut pos = 0;
         for index in 0..sample_len {
             let logits = if index > 0 {
                 let context_size = if index > 0 { 1 } else { tokens.len() };
                 let start_pos = tokens.len().saturating_sub(context_size);
                 let ctxt = &tokens[start_pos..];
                 let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-                let logits = self.model.language_model.forward(&input, pos)?;
-                pos += context_size;
-                logits
+                self.model.lm_forward(&input)?
             } else {
                 let (_b, _c, h, w) = self.image.dims4()?;
                 let h = h / self.model.patch_size;
                 let w = w / self.model.patch_size;
-                let image_embeds = self.model.vision_tower.forward(&self.image)?;
-                let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
+                let image_embeds = self.model.encode_image(&self.image)?;
                 println!("generated image embeddings {image_embeds:?}");
                 let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
                 for &t in tokens.iter() {
@@ -124,12 +120,7 @@ impl TextGeneration {
 
                 input_embeds.push(end_embeds);
                 let input_embeds = Tensor::cat(&input_embeds, 1)?;
-                let logits = self
-                    .model
-                    .language_model
-                    .forward_embeds(&input_embeds, None, pos)?;
-                pos += input_embeds.dim(1)?;
-                logits
+                self.model.lm_forward_embeds(&input_embeds)?
             };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
index 33e0aca0b9..802d87b160 100644
--- a/candle-transformers/src/models/pixtral/llava.rs
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -48,6 +48,7 @@ pub struct Model {
     pub vision_tower: vision_model::Model,
     pub patch_size: usize,
     pub dtype: candle::DType,
+    pub pos: usize,
 }
 
 impl Model {
@@ -67,6 +68,31 @@ impl Model {
             vision_tower,
             patch_size: cfg.vision_config.patch_size,
             dtype: vb.dtype(),
+            pos: 0,
         })
     }
+
+    pub fn clear_kv_cache(&mut self) {
+        self.language_model.clear_kv_cache();
+        self.pos = 0;
+    }
+
+    pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
+        let image_embeds = self.vision_tower.forward(&image)?;
+        self.multi_modal_projector.forward(&image_embeds)
+    }
+
+    pub fn lm_forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let (_, seq_len) = input_ids.dims2()?;
+        let logits = self.language_model.forward(input_ids, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
+
+    pub fn lm_forward_embeds(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let logits = self.language_model.forward_embeds(xs, None, self.pos)?;
+        self.pos += seq_len;
+        Ok(logits)
+    }
 }

From 31ed53e33c506b46799babda97a8ae5a802a2eab Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 30 Sep 2024 21:14:36 +0200
Subject: [PATCH 2/2] Clippy fix.
---
 candle-transformers/src/models/pixtral/llava.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/pixtral/llava.rs b/candle-transformers/src/models/pixtral/llava.rs
index 802d87b160..4aff26a784 100644
--- a/candle-transformers/src/models/pixtral/llava.rs
+++ b/candle-transformers/src/models/pixtral/llava.rs
@@ -78,7 +78,7 @@ impl Model {
     }
 
     pub fn encode_image(&self, image: &Tensor) -> Result<Tensor> {
-        let image_embeds = self.vision_tower.forward(&image)?;
+        let image_embeds = self.vision_tower.forward(image)?;
         self.multi_modal_projector.forward(&image_embeds)
     }
 
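
Not part of the patch: a minimal, hypothetical usage sketch of the refactored Model API introduced above, i.e. encode_image for the vision path, lm_forward_embeds for the multimodal prefill, lm_forward for per-token decoding, and clear_kv_cache to reset the internal pos counter between generations. The helper names prefill_with_image and decode_step, the llava::Model import path, and the precomputed prompt_embeds argument are assumptions for illustration; only the Model methods and fields themselves come from the diff above.

// Sketch only: assumes llava::Model is publicly exported and that prompt
// token embeddings have already been computed by the caller.
use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::pixtral::llava::Model;

// Prefill: encode the image once, concatenate with the prompt embeddings,
// and run the language model; `pos` advances by the combined sequence length.
fn prefill_with_image(
    model: &mut Model,
    image: &Tensor,
    prompt_embeds: &Tensor,
) -> Result<Tensor> {
    // Fresh generation: reset the KV cache and the internal `pos` counter.
    model.clear_kv_cache();
    // Vision tower + multi-modal projector in a single call.
    let image_embeds = model.encode_image(image)?.to_dtype(model.dtype)?;
    let input_embeds = Tensor::cat(&[prompt_embeds, &image_embeds], 1)?;
    model.lm_forward_embeds(&input_embeds)
}

// Decode: feed one token id per step; `pos` keeps advancing inside the model.
fn decode_step(model: &mut Model, token: u32, device: &Device) -> Result<Tensor> {
    let input = Tensor::new(&[token], device)?.unsqueeze(0)?;
    let logits = model.lm_forward(&input)?;
    logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)
}

Keeping pos inside Model is what lets the main.rs hunks above drop the hand-threaded position counter from the generation loop.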