diff --git a/candle-examples/examples/chinese_clip/main.rs b/candle-examples/examples/chinese_clip/main.rs
new file mode 100644
index 0000000000..5cee1fc81e
--- /dev/null
+++ b/candle-examples/examples/chinese_clip/main.rs
@@ -0,0 +1,224 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{DType, Device, Tensor};
+use candle_nn as nn;
+use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};
+use clap::Parser;
+use tokenizers::Tokenizer;
+
+#[derive(Parser)]
+struct Args {
+    #[arg(long)]
+    model: Option<String>,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    #[arg(long, use_value_delimiter = true)]
+    images: Option<Vec<String>>,
+
+    #[arg(long)]
+    cpu: bool,
+
+    #[arg(long, use_value_delimiter = true)]
+    sequences: Option<Vec<String>>,
+}
+
+fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    tracing_subscriber::fmt::init();
+
+    let device = candle_examples::device(args.cpu)?;
+    let var = load_weights(args.model, &device)?;
+    let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?;
+    tracing::info!("Transformer loaded. ");
+
+    let (pixel_values, vec_imgs) = load_images(args.images, &device)?;
+    tracing::info!("Images loaded. ");
+
+    let tokenizer = load_tokenizer()?;
+    let (input_ids, type_ids, attention_mask, text_sequences) =
+        tokenize_sequences(args.sequences, &tokenizer, &device)?;
+
+    tracing::info!("Computing ... ");
+    let (_logits_per_text, logits_per_image) = clip_model.forward(
+        &pixel_values,
+        &input_ids,
+        Some(&type_ids),
+        Some(&attention_mask),
+    )?;
+    let softmax_image = nn::ops::softmax(&logits_per_image, 1)?;
+
+    let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::<f32>()?;
+
+    let probability_vec = softmax_image_vec
+        .iter()
+        .map(|v| v * 100.0)
+        .collect::<Vec<f32>>();
+
+    let probability_per_image = probability_vec.len() / vec_imgs.len();
+
+    for (i, img) in vec_imgs.iter().enumerate() {
+        let start = i * probability_per_image;
+        let end = start + probability_per_image;
+        let prob = &probability_vec[start..end];
+        tracing::info!("\n\nResults for image: {}\n", img);
+
+        for (i, p) in prob.iter().enumerate() {
+            tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]);
+        }
+    }
+
+    Ok(())
+}
+
+pub fn load_weights(model: Option<String>, device: &Device) -> anyhow::Result<nn::VarBuilder<'static>> {
+    let model_file = match model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let repo = hf_hub::Repo::with_revision(
+                "OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
+                hf_hub::RepoType::Model,
+                "refs/pr/3".to_string(),
+            );
+            let api = api.repo(repo);
+            api.get("model.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+
+    Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? })
+}
+
+pub fn load_tokenizer() -> anyhow::Result<Tokenizer> {
+    let tokenizer_file = {
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = hf_hub::Repo::with_revision(
+            "OFA-Sys/chinese-clip-vit-base-patch16".to_string(),
+            hf_hub::RepoType::Model,
+            "refs/pr/3".to_string(),
+        );
+        let api = api.repo(repo);
+        api.get("tokenizer.json")?
+    };
+
+    Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg)
+}
+
+pub fn tokenize_sequences(
+    sequences: Option<Vec<String>>,
+    tokenizer: &Tokenizer,
+    device: &Device,
+) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec<String>)> {
+    let vec_seq = match sequences {
+        Some(seq) => seq,
+        None => vec![
+            "自行车比赛".to_string(),
+            "两只猫咪".to_string(),
+            "拿着蜡烛的机器人".to_string(),
+        ],
+    };
+
+    let mut input_ids = vec![];
+    let mut type_ids = vec![];
+    let mut attention_mask = vec![];
+    let mut max_len = 0;
+
+    for seq in vec_seq.clone() {
+        let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?;
+        input_ids.push(encoding.get_ids().to_vec());
+        type_ids.push(encoding.get_type_ids().to_vec());
+        attention_mask.push(encoding.get_attention_mask().to_vec());
+        if encoding.get_ids().len() > max_len {
+            max_len = encoding.get_ids().len();
+        }
+    }
+
+    let pad_id = *tokenizer
+        .get_vocab(true)
+        .get("[PAD]")
+        .ok_or(anyhow::Error::msg("No pad token"))?;
+
+    let input_ids: Vec<Vec<u32>> = input_ids
+        .iter_mut()
+        .map(|item| {
+            item.extend(vec![pad_id; max_len - item.len()]);
+            item.to_vec()
+        })
+        .collect();
+
+    let type_ids: Vec<Vec<u32>> = type_ids
+        .iter_mut()
+        .map(|item| {
+            item.extend(vec![0; max_len - item.len()]);
+            item.to_vec()
+        })
+        .collect();
+
+    let attention_mask: Vec<Vec<u32>> = attention_mask
+        .iter_mut()
+        .map(|item| {
+            item.extend(vec![0; max_len - item.len()]);
+            item.to_vec()
+        })
+        .collect();
+
+    let input_ids = Tensor::new(input_ids, device)?;
+    let type_ids = Tensor::new(type_ids, device)?;
+    let attention_mask = Tensor::new(attention_mask, device)?;
+
+    Ok((input_ids, type_ids, attention_mask, vec_seq))
+}
+
+pub fn load_images(
+    images: Option<Vec<String>>,
+    device: &Device,
+) -> anyhow::Result<(Tensor, Vec<String>)> {
+    let vec_imgs = match images {
+        Some(imgs) => imgs,
+        None => vec![
+            "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(),
+            "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
+        ],
+    };
+
+    let mut images = vec![];
+
+    for path in vec_imgs.iter() {
+        let tensor = load_image(path, 224, device)?;
+        images.push(tensor);
+    }
+
+    let images = Tensor::stack(&images, 0)?.to_device(device)?;
+    Ok((images, vec_imgs))
+}
+
+fn load_image<T: AsRef<std::path::Path>>(
+    path: T,
+    image_size: usize,
+    device: &Device,
+) -> anyhow::Result<Tensor> {
+    let img = image::ImageReader::open(path)?.decode()?;
+    let (height, width) = (image_size, image_size);
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::Triangle,
+    );
+
+    let img = img.to_rgb8().into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?;
+    let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
+    let std =
+        Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
+    let img = (img.to_dtype(DType::F32)? / 255.)?
+        .broadcast_sub(&mean)?
+        .broadcast_div(&std)?;
+
+    Ok(img)
+}
diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs
new file mode 100644
index 0000000000..88472f0b88
--- /dev/null
+++ b/candle-transformers/src/models/chinese_clip/mod.rs
@@ -0,0 +1,208 @@
+//! Chinese Contrastive Language-Image Pre-Training
+//!
+//! Chinese Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/OFA-Sys/Chinese-CLIP
+//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+
+use candle::{Module, Result, Tensor, D};
+use candle_nn as nn;
+
+use text_model::ChineseClipTextTransformer;
+use vision_model::ChineseClipVisionTransformer;
+
+pub mod text_model;
+pub mod vision_model;
+
+#[derive(Debug, Clone, Copy)]
+pub enum Activation {
+    QuickGelu,
+    Gelu,
+    GeluNew,
+    Relu,
+}
+
+impl From<String> for Activation {
+    fn from(value: String) -> Self {
+        match value.as_str() {
+            "quick_gelu" => Activation::QuickGelu,
+            "gelu" => Activation::Gelu,
+            "gelu_new" => Activation::GeluNew,
+            "relu" => Activation::Relu,
+            _ => panic!("Invalid activation function: {}", value),
+        }
+    }
+}
+
+impl Module for Activation {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
+            Activation::Gelu => xs.gelu_erf(),
+            Activation::GeluNew => xs.gelu(),
+            Activation::Relu => xs.relu(),
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct ChineseClipConfig {
+    pub text_config: text_model::ChineseClipTextConfig,
+    pub vision_config: vision_model::ChineseClipVisionConfig,
+    pub projection_dim: usize,
+    pub logit_scale_init_value: f32,
+    pub image_size: usize,
+}
+
+impl ChineseClipConfig {
+    /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json
+    pub fn clip_vit_base_patch16() -> Self {
+        let text_config = text_model::ChineseClipTextConfig::clip_vit_base_patch16();
+        let vision_config = vision_model::ChineseClipVisionConfig::clip_vit_base_patch16();
+
+        Self {
+            text_config,
+            vision_config,
+            projection_dim: 512,
+            logit_scale_init_value: 2.6592,
+            image_size: 512,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub enum EncoderConfig {
+    Text(text_model::ChineseClipTextConfig),
+    Vision(vision_model::ChineseClipVisionConfig),
+}
+
+impl EncoderConfig {
+    pub fn embed_dim(&self) -> usize {
+        match self {
+            Self::Text(c) => c.hidden_size,
+            Self::Vision(c) => c.hidden_size,
+        }
+    }
+
+    pub fn num_attention_heads(&self) -> usize {
+        match self {
+            Self::Text(c) => c.num_attention_heads,
+            Self::Vision(c) => c.num_attention_heads,
+        }
+    }
+
+    pub fn intermediate_size(&self) -> usize {
+        match self {
+            Self::Text(c) => c.intermediate_size,
+            Self::Vision(c) => c.intermediate_size,
+        }
+    }
+
+    pub fn num_hidden_layers(&self) -> usize {
+        match self {
+            Self::Text(c) => c.num_hidden_layers,
+            Self::Vision(c) => c.num_hidden_layers,
+        }
+    }
+
+    pub fn activation(&self) -> Activation {
+        match self {
+            Self::Text(c) => c.hidden_act,
+            Self::Vision(c) => c.hidden_act,
+        }
+    }
+
+    pub fn layer_norm_eps(&self) -> f64 {
+        match self {
+            Self::Text(c) => c.layer_norm_eps,
+            Self::Vision(c) => c.layer_norm_eps,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct ChineseClipModel {
+    text_model: ChineseClipTextTransformer,
+    vision_model: ChineseClipVisionTransformer,
+    visual_projection: nn::Linear,
+    text_projection: nn::Linear,
+    logit_scale: Tensor,
+}
+
+impl ChineseClipModel {
+    pub fn new(vs: nn::VarBuilder, c: &ChineseClipConfig) -> Result<Self> {
+        let text_model = ChineseClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?;
+
+        let vision_model =
+            ChineseClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?;
+
+        let vision_embed_dim = c.vision_config.hidden_size;
+        let vision_projection = nn::linear_no_bias(
+            vision_embed_dim,
+            c.projection_dim,
+            vs.pp("visual_projection"),
+        )?;
+
+        let text_embed_dim = c.text_config.hidden_size;
+        let text_projection =
+            nn::linear_no_bias(text_embed_dim, c.projection_dim, vs.pp("text_projection"))?;
+
+        let logit_scale = if vs.contains_tensor("logit_scale") {
+            vs.get(&[], "logit_scale")?
+        } else {
+            Tensor::new(&[c.logit_scale_init_value], vs.device())?
+        };
+
+        Ok(Self {
+            text_model,
+            vision_model,
+            visual_projection: vision_projection,
+            text_projection,
+            logit_scale,
+        })
+    }
+
+    pub fn get_text_features(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let output = self
+            .text_model
+            .forward(input_ids, token_type_ids, attention_mask)?;
+        self.text_projection.forward(&output)
+    }
+
+    pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
+        pixel_values
+            .apply(&self.vision_model)?
+            .apply(&self.visual_projection)
+    }
+
+    pub fn forward(
+        &self,
+        pixel_values: &Tensor,
+        input_ids: &Tensor,
+        token_type_ids: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
+        let image_features = self.get_image_features(pixel_values)?;
+        let text_features = self.get_text_features(input_ids, token_type_ids, attention_mask)?;
+
+        let image_features_normalized = div_l2_norm(&image_features)?;
+        let text_features_normalized = div_l2_norm(&text_features)?;
+
+        let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
+        let logit_scale = self.logit_scale.exp()?;
+        let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
+        let logits_per_image = logits_per_text.t()?;
+        Ok((logits_per_text, logits_per_image))
+    }
+}
+
+pub fn div_l2_norm(v: &Tensor) -> Result<Tensor> {
+    let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?;
+    v.broadcast_div(&l2_norm)
+}
diff --git a/candle-transformers/src/models/chinese_clip/text_model.rs b/candle-transformers/src/models/chinese_clip/text_model.rs
new file mode 100644
index 0000000000..19499709a7
--- /dev/null
+++ b/candle-transformers/src/models/chinese_clip/text_model.rs
@@ -0,0 +1,540 @@
+//! Chinese Contrastive Language-Image Pre-Training
+//!
+//! Chinese Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
+//! pairs of images with related texts.
+//!
+//! https://github.com/OFA-Sys/Chinese-CLIP
+//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+
+use candle::{DType, Device, IndexOp, Module, Result, Tensor};
+use candle_nn as nn;
+
+use super::Activation;
+
+/// Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+/// positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+/// [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+/// For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+/// with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+#[derive(Clone, Debug)] +pub enum PositionEmbeddingType { + Absolute, + RelativeKey, + RelativeKeyQuery, +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: Activation, + pub hidden_dropout_prob: f32, + pub attention_probs_dropout_prob: f64, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub initializer_range: f64, + pub initializer_factor: f64, + pub layer_norm_eps: f64, + pub pad_token_id: usize, + pub position_embedding_type: PositionEmbeddingType, + pub use_cache: bool, +} + +impl Default for ChineseClipTextConfig { + fn default() -> Self { + Self { + vocab_size: 30522, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: Activation::Gelu, + hidden_dropout_prob: 0.1, + attention_probs_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + initializer_factor: 1.0, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + } + } +} + +impl ChineseClipTextConfig { + /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + pub fn clip_vit_base_patch16() -> Self { + Self { + vocab_size: 21128, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: Activation::Gelu, + hidden_dropout_prob: 0.1, + attention_probs_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + initializer_factor: 1.0, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextEmbeddings { + word_embeddings: nn::Embedding, + position_embeddings: nn::Embedding, + token_type_embeddings: nn::Embedding, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + position_embedding_type: PositionEmbeddingType, + position_ids: Tensor, + token_type_ids: Tensor, +} + +impl ChineseClipTextEmbeddings { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let word_embeddings = nn::embedding( + config.vocab_size, + config.hidden_size, + var.pp("word_embeddings"), + )?; + let position_embeddings = nn::embedding( + config.max_position_embeddings, + config.hidden_size, + var.pp("position_embeddings"), + )?; + let token_type_embeddings = nn::embedding( + config.type_vocab_size, + config.hidden_size, + var.pp("token_type_embeddings"), + )?; + let layer_norm = nn::layer_norm::( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + let position_ids = + Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())? 
+ .unsqueeze(0)?; + let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?; + + Ok(Self { + word_embeddings, + position_embeddings, + token_type_embeddings, + layer_norm, + dropout, + position_embedding_type: config.position_embedding_type.clone(), + position_ids, + token_type_ids, + }) + } + + fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result { + let (_batch_size, seq_length) = xs.dims2()?; + let position_ids = (0..seq_length as u32).collect::>(); + let position_ids = self.position_ids.index_select( + &Tensor::new(&position_ids[..], self.position_ids.device())?, + 1, + )?; + + let word_embeddings = self.word_embeddings.forward(xs)?; + + let token_type_ids = match token_type_ids { + Some(token_type_ids) => token_type_ids, + None => &self.token_type_ids.i((.., 0..seq_length))?, + }; + let token_type_ids = token_type_ids.expand(xs.shape())?; + let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?; + + let embeddings = (&word_embeddings + token_type_embeddings)?; + let embeddings = match self.position_embedding_type { + PositionEmbeddingType::Absolute => { + let position_embeddings = self.position_embeddings.forward(&position_ids)?; + let position_embeddings = position_embeddings.expand(embeddings.shape())?; + (embeddings + position_embeddings)? + } + _ => embeddings, + }; + let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = self.dropout.forward(&embeddings, false)?; + Ok(embeddings) + } +} + +/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`] +#[derive(Clone, Debug)] +struct ChineseClipTextSelfOutput { + dense: nn::Linear, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + span: tracing::Span, +} + +impl ChineseClipTextSelfOutput { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?; + let layer_norm = nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "self-out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states, false)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) 
+ } +} + +/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`] +#[derive(Clone, Debug)] +struct ChineseClipTextSelfAttention { + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + dropout: nn::Dropout, + num_attention_heads: usize, + attention_head_size: usize, + span: tracing::Span, + span_softmax: tracing::Span, +} + +impl ChineseClipTextSelfAttention { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + let hidden_size = config.hidden_size; + let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?; + let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?; + let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?; + Ok(Self { + query, + key, + value, + dropout, + num_attention_heads: config.num_attention_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), + }) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result { + let mut new_x_shape = xs.dims().to_vec(); + new_x_shape.pop(); + new_x_shape.push(self.num_attention_heads); + new_x_shape.push(self.attention_head_size); + let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; + xs.contiguous() + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let query_layer = self.query.forward(hidden_states)?; + let key_layer = self.key.forward(hidden_states)?; + let value_layer = self.value.forward(hidden_states)?; + + let query_layer = self.transpose_for_scores(&query_layer)?; + let key_layer = self.transpose_for_scores(&key_layer)?; + let value_layer = self.transpose_for_scores(&value_layer)?; + + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; + let attention_scores = attention_scores.broadcast_add(attention_mask)?; + let attention_probs = { + let _enter_sm = self.span_softmax.enter(); + nn::ops::softmax(&attention_scores, candle::D::Minus1)? 
+ }; + let attention_probs = self.dropout.forward(&attention_probs, false)?; + + let context_layer = attention_probs.matmul(&value_layer)?; + let context_layer = context_layer.transpose(1, 2)?.contiguous()?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; + Ok(context_layer) + } +} + +/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`] +#[derive(Clone, Debug)] +struct ChineseClipTextAttention { + self_attention: ChineseClipTextSelfAttention, + self_output: ChineseClipTextSelfOutput, + span: tracing::Span, +} + +impl ChineseClipTextAttention { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?; + let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?; + Ok(Self { + self_attention, + self_output, + span: tracing::span!(tracing::Level::TRACE, "attn"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?; + let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; + Ok(attention_output) + } +} + +type HiddenActLayer = Activation; + +/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`] +#[derive(Clone, Debug)] +struct ChineseClipTextIntermediate { + dense: nn::Linear, + intermediate_act: HiddenActLayer, + span: tracing::Span, +} + +impl ChineseClipTextIntermediate { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear( + config.hidden_size, + config.intermediate_size, + var.pp("dense"), + )?; + Ok(Self { + dense, + intermediate_act: config.hidden_act, + span: tracing::span!(tracing::Level::TRACE, "inter"), + }) + } +} + +impl Module for ChineseClipTextIntermediate { + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let ys = self.intermediate_act.forward(&hidden_states)?; + Ok(ys) + } +} + +/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`] +#[derive(Clone, Debug)] +struct ChineseClipTextOutput { + dense: nn::Linear, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + span: tracing::Span, +} + +impl ChineseClipTextOutput { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear( + config.intermediate_size, + config.hidden_size, + var.pp("dense"), + )?; + let layer_norm = nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states, false)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) 
+ } +} + +/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`] +#[derive(Clone, Debug)] +struct ChineseClipTextLayer { + attention: ChineseClipTextAttention, + intermediate: ChineseClipTextIntermediate, + output: ChineseClipTextOutput, + span: tracing::Span, +} + +impl ChineseClipTextLayer { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?; + let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?; + let output = ChineseClipTextOutput::new(var.pp("output"), config)?; + Ok(Self { + attention, + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let attention_output = self.attention.forward(hidden_states, attention_mask)?; + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + let intermediate_output = self.intermediate.forward(&attention_output)?; + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + Ok(layer_output) + } +} + +#[derive(Clone, Debug)] +struct Tanh; + +impl Tanh { + pub fn new() -> Self { + Self {} + } +} +impl Module for Tanh { + fn forward(&self, xs: &Tensor) -> Result { + xs.tanh() + } +} + +#[derive(Clone, Debug)] +struct ChineseClipTextPooler { + dense: nn::Linear, + activation: Tanh, +} + +impl ChineseClipTextPooler { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?; + let activation = Tanh::new(); + Ok(Self { dense, activation }) + } +} + +impl Module for ChineseClipTextPooler { + fn forward(&self, hidden_states: &Tensor) -> Result { + let first_token_tensor = hidden_states.i((.., 0))?; + let pooled_output = self.dense.forward(&first_token_tensor)?; + let pooled_output = self.activation.forward(&pooled_output)?; + Ok(pooled_output) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipTextEncoder { + layers: Vec, + span: tracing::Span, +} + +impl ChineseClipTextEncoder { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let layers = (0..config.num_hidden_layers) + .map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(ChineseClipTextEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, attention_mask)? 
+ } + Ok(hidden_states) + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextTransformer { + embeddings: ChineseClipTextEmbeddings, + encoder: ChineseClipTextEncoder, + pooler: Option, + pub device: Device, + span: tracing::Span, +} + +impl ChineseClipTextTransformer { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?; + let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?; + // see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362 + // In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file. + let pooler = if var.contains_tensor("pooler") { + Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?) + } else { + None + }; + Ok(Self { + embeddings, + encoder, + pooler, + device: var.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + ) -> Result { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; + let attention_mask = match attention_mask { + Some(attention_mask) => attention_mask.clone(), + None => input_ids.ones_like()?, + }; + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995 + let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?; + let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?; + let encoder_output = encoder_outputs.i((.., 0, ..))?; + let pooled_output = match &self.pooler { + Some(pooler) => pooler.forward(&encoder_output)?, + None => encoder_output, + }; + + Ok(pooled_output) + } +} + +fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result { + let attention_mask = match attention_mask.rank() { + 3 => attention_mask.unsqueeze(1)?, + 2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?, + _ => candle::bail!("Wrong shape for input_ids or attention_mask"), + }; + let attention_mask = attention_mask.to_dtype(dtype)?; + // torch.finfo(dtype).min + (attention_mask.ones_like()? - &attention_mask)? + .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) +} diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs new file mode 100644 index 0000000000..2d345e0f4a --- /dev/null +++ b/candle-transformers/src/models/chinese_clip/vision_model.rs @@ -0,0 +1,385 @@ +//! Chinese contrastive Language-Image Pre-Training +//! +//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/OFA-Sys/Chinese-CLIP +//! 
https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py + +use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D}; +use candle_nn as nn; + +use super::{Activation, EncoderConfig}; + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub projection_dim: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_channels: usize, + pub image_size: usize, + pub patch_size: usize, + pub hidden_act: Activation, + pub layer_norm_eps: f64, + pub attention_dropout: f32, + pub initializer_range: f32, + pub initializer_factor: f32, +} + +impl Default for ChineseClipVisionConfig { + fn default() -> Self { + ChineseClipVisionConfig { + hidden_size: 768, + intermediate_size: 3072, + projection_dim: 512, + num_hidden_layers: 12, + num_attention_heads: 12, + num_channels: 3, + image_size: 224, + patch_size: 32, + hidden_act: Activation::QuickGelu, + layer_norm_eps: 1e-5, + attention_dropout: 0.0, + initializer_range: 0.02, + initializer_factor: 1.0, + } + } +} + +impl ChineseClipVisionConfig { + /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + pub fn clip_vit_base_patch16() -> Self { + Self { + hidden_size: 768, + intermediate_size: 3072, + projection_dim: 512, + num_hidden_layers: 12, + num_attention_heads: 12, + num_channels: 3, + image_size: 224, + patch_size: 16, + hidden_act: Activation::QuickGelu, + layer_norm_eps: 1e-5, + attention_dropout: 0.0, + initializer_range: 0.02, + initializer_factor: 1.0, + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionEmbeddings { + patch_embedding: nn::Conv2d, + position_ids: Tensor, + class_embedding: Tensor, + position_embedding: nn::Embedding, +} + +impl ChineseClipVisionEmbeddings { + pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result { + let embed_dim = config.hidden_size; + // originally nn.Parameter + let class_embedding = if var.contains_tensor("class_embedding") { + var.get(embed_dim, "class_embedding")? + } else { + Tensor::randn(0f32, 1f32, embed_dim, var.device())? + }; + + let num_patches = (config.image_size / config.patch_size).pow(2); + let num_positions = num_patches + 1; + let position_ids = Tensor::arange(0, num_positions as i64, var.device())?; + + let conv2dconfig = nn::Conv2dConfig { + stride: config.patch_size, + ..Default::default() + }; + let position_embedding = + nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?; + let patch_embedding = nn::conv2d_no_bias( + config.num_channels, + embed_dim, + config.patch_size, + conv2dconfig, + var.pp("patch_embedding"), + )?; + Ok(Self { + patch_embedding, + position_ids, + class_embedding, + position_embedding, + }) + } +} + +impl Module for ChineseClipVisionEmbeddings { + fn forward(&self, xs: &Tensor) -> Result { + let batch_size = xs.shape().dims(); + let patch_embeds = self + .patch_embedding + .forward(xs)? + .flatten_from(2)? 
+ .transpose(1, 2)?; + let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?)); + let class_embeds = self.class_embedding.expand(shape)?; + let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?; + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + embeddings.broadcast_add(&position_embedding) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionAttention { + k_proj: nn::Linear, + v_proj: nn::Linear, + q_proj: nn::Linear, + out_proj: nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ChineseClipVisionAttention { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let embed_dim = config.embed_dim(); + let num_attention_heads = config.num_attention_heads(); + let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?; + let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?; + let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?; + let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(ChineseClipVisionAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + + let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask { + attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)? + .reshape((bsz * self.num_attention_heads, seq_len, src_len))? + } else { + attn_weights + }; + + let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionMlp { + fc1: nn::Linear, + fc2: nn::Linear, + activation: Activation, +} + +impl ChineseClipVisionMlp { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let fc1 = nn::linear( + config.embed_dim(), + config.intermediate_size(), + var.pp("fc1"), + )?; + let fc2 = nn::linear( + config.intermediate_size(), + config.embed_dim(), + var.pp("fc2"), + )?; + + Ok(ChineseClipVisionMlp { + fc1, + fc2, + activation: config.activation(), + }) + } +} + +impl ChineseClipVisionMlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) 
+ } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionEncoderLayer { + self_attn: ChineseClipVisionAttention, + layer_norm1: nn::LayerNorm, + mlp: ChineseClipVisionMlp, + layer_norm2: nn::LayerNorm, +} + +impl ChineseClipVisionEncoderLayer { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?; + let layer_norm1 = nn::layer_norm( + config.embed_dim(), + config.layer_norm_eps(), + var.pp("layer_norm1"), + )?; + let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?; + let layer_norm2 = nn::layer_norm( + config.embed_dim(), + config.layer_norm_eps(), + var.pp("layer_norm2"), + )?; + + Ok(ChineseClipVisionEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionEncoder { + layers: Vec, +} + +impl ChineseClipVisionEncoder { + pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let vs = var.pp("layers"); + let mut layers: Vec = Vec::new(); + for index in 0..config.num_hidden_layers() { + let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?; + layers.push(layer) + } + Ok(ChineseClipVisionEncoder { layers }) + } + + pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } + + // required by LLaVA + pub fn output_hidden_states( + &self, + xs: &Tensor, + causal_attention_mask: Option<&Tensor>, + ) -> Result> { + let mut xs = xs.clone(); + let mut hidden_states = Vec::new(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + hidden_states.push(xs.clone()); + } + Ok(hidden_states) + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionTransformer { + embeddings: ChineseClipVisionEmbeddings, + encoder: ChineseClipVisionEncoder, + pre_layer_norm: nn::LayerNorm, + final_layer_norm: nn::LayerNorm, +} + +impl ChineseClipVisionTransformer { + pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result { + let embed_dim = config.hidden_size; + let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?; + let pre_layer_norm = + nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?; + let encoder = ChineseClipVisionEncoder::new( + var.pp("encoder"), + &EncoderConfig::Vision(config.clone()), + )?; + let final_layer_norm = + nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?; + Ok(Self { + embeddings, + encoder, + final_layer_norm, + pre_layer_norm, + }) + } + // required by LLaVA + pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result> { + let hidden_states = pixel_values + .apply(&self.embeddings)? 
+            .apply(&self.pre_layer_norm)?;
+
+        let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
+        let encoder_outputs = result.last().unwrap();
+        let pooled_output = encoder_outputs.i((.., 0, ..))?;
+        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
+        Ok(result)
+    }
+}
+
+impl Module for ChineseClipVisionTransformer {
+    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
+
+        let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
+
+        // referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
+        let pooled_output = encoder_outputs.i((.., 0, ..))?;
+        self.final_layer_norm.forward(&pooled_output)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 80cd4f810c..6ed7a8b580 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -5,6 +5,7 @@ pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
 pub mod chatglm;
+pub mod chinese_clip;
 pub mod clip;
 pub mod codegeex4_9b;
 pub mod colpali;
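
Taken together, the example binary and the new module give the full inference path: load weights into a `VarBuilder`, build `ChineseClipConfig::clip_vit_base_patch16`, then call `ChineseClipModel::forward` with pixel values, token ids, type ids, and attention mask. For readers who want to drive the model from their own code rather than through the example's CLI, here is a minimal sketch condensed from `main.rs` above. It is not part of the diff: the `run` helper, the caller-supplied weights path, and the zero-filled placeholder tensors (which stand in for the outputs of `load_images` and `tokenize_sequences`) are illustrative assumptions.

```rust
use candle::{DType, Device, Tensor};
use candle_nn as nn;
use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel};

fn run(weights: std::path::PathBuf) -> anyhow::Result<()> {
    let device = Device::Cpu;
    let config = ChineseClipConfig::clip_vit_base_patch16();
    // Memory-map the safetensors file downloaded from OFA-Sys/chinese-clip-vit-base-patch16.
    let vb = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    let model = ChineseClipModel::new(vb, &config)?;

    // Placeholders only: 2 preprocessed 3x224x224 images and 3 padded text sequences of length 8.
    let pixel_values = Tensor::zeros((2, 3, 224, 224), DType::F32, &device)?;
    let input_ids = Tensor::zeros((3, 8), DType::U32, &device)?;
    let type_ids = input_ids.zeros_like()?;
    let attention_mask = input_ids.ones_like()?;

    let (_logits_per_text, logits_per_image) =
        model.forward(&pixel_values, &input_ids, Some(&type_ids), Some(&attention_mask))?;
    // Per-image probabilities over the text candidates, as in the example binary.
    let probs = nn::ops::softmax(&logits_per_image, 1)?;
    println!("{probs}");
    Ok(())
}
```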
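
One implementation detail from `text_model.rs` worth highlighting is how padding is excluded from attention: `get_extended_attention_mask` turns the 0/1 attention mask into an additive bias of roughly `f32::MIN` on padded positions, so the softmax drives their probability to zero, mirroring the BERT code linked in the comments. The snippet below is a standalone illustration of that transformation, not code from the PR; the batch size, sequence length, and mask values are arbitrary.

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // One sequence of length 4 with the last position padded: mask = [1, 1, 1, 0].
    let mask = Tensor::new(&[1u32, 1, 1, 0], &device)?
        .reshape((1, 4))?
        .to_dtype(DType::F32)?;
    // (1 - mask) * f32::MIN, broadcast to (batch, 1, 1, seq) as in get_extended_attention_mask.
    let extended = (mask.ones_like()? - &mask)?
        .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(&device)?)?
        .unsqueeze(1)?
        .unsqueeze(1)?;
    // Prints ~[0, 0, 0, -3.4e38]: kept positions are untouched, padded slots get a huge negative bias.
    println!("{extended}");
    Ok(())
}
```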