Add ColPali #2524

Merged 3 commits on Oct 1, 2024
5 changes: 5 additions & 0 deletions candle-examples/Cargo.toml
@@ -36,6 +36,7 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]

[[example]]
name = "colpali"
required-features = ["pdf2image"]
18 changes: 18 additions & 0 deletions candle-examples/examples/colpali/README.md
@@ -0,0 +1,18 @@
# ColPali

ColPali ranks the pages of a PDF by how well they match a text prompt: page images and the query are embedded into a shared late-interaction space and scored with MaxSim.

[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)

```bash
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```

Sample output:

```
Prompt: What is Positional Encoding
top 3 pages most similar to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```
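
Each page is scored with ColPali's late-interaction (MaxSim) rule: every text-token embedding is compared with every image-patch embedding, the best match for each text token is kept, and those per-token maxima are summed. A minimal sketch of that rule for a single page and prompt, assuming both embedding matrices are already L2-normalized `f32` tensors (the function name is illustrative; `main.rs` below computes the same quantity in batched form):

```rust
use candle::{Result, Tensor};

// Score one page against one prompt. `text` is (text_tokens, dim) and
// `image` is (patches, dim); both are assumed L2-normalized.
fn maxsim_score(text: &Tensor, image: &Tensor) -> Result<f32> {
    let sims = text.matmul(&image.t()?)?; // (text_tokens, patches)
    let best_per_token = sims.max(1)?; // best-matching patch per text token
    best_per_token.sum_all()?.to_scalar::<f32>()
}
```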
268 changes: 268 additions & 0 deletions candle-examples/examples/colpali/main.rs
@@ -0,0 +1,268 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;

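/// Renders the pages of a PDF to images and ranks them against a text prompt
/// using ColPali embeddings, keeping the indices of the `top_k` best pages.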
struct PageRetriever {
    model: Model,
    config: paligemma::Config,
    pdf: PDF,
    device: Device,
    tokenizer: Tokenizer,
    range: pdf2image::Pages,
    batch_size: usize,
    top_k: usize,
}

impl PageRetriever {
    fn new(
        model: Model,
        config: paligemma::Config,
        pdf: PDF,
        tokenizer: Tokenizer,
        device: &Device,
        range: Option<pdf2image::Pages>,
        batch_size: usize,
        top_k: usize,
    ) -> Self {
        let page_count = pdf.page_count();
        Self {
            model,
            config,
            pdf,
            device: device.clone(),
            tokenizer,
            range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
            batch_size,
            top_k,
        }
    }

    fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
        let pages = self
            .pdf
            .render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
        Ok(pages)
    }

    fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
        let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), &self.device)
            })
            .collect::<candle::Result<Vec<_>>>()?;
        let input = Tensor::stack(&token_ids, 0)?;
        Ok(input)
    }

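    // Resize each page to the model's square input size and normalize the RGB
    // values from [0, 255] to [-1, 1] before stacking them into one batch.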
    fn images_to_tensor(
        &self,
        pages: &[DynamicImage],
        image_size: usize,
    ) -> anyhow::Result<Tensor> {
        let mut images = vec![];
        for page in pages.iter() {
            let img = page.resize_to_fill(
                image_size as u32,
                image_size as u32,
                image::imageops::FilterType::Triangle,
            );
            let img = img.to_rgb8();
            let img = img.into_raw();
            let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
                .permute((2, 0, 1))?
                .to_dtype(DType::F32)?
                .affine(2. / 255., -1.)?;
            images.push(img);
        }
        let images = Tensor::stack(&images, 0)?;
        Ok(images)
    }

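    // Score every page against the prompt in batches: for each text token,
    // take the best image-patch similarity (max over patches), sum over text
    // tokens (MaxSim), then return the indices of the top_k pages.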
    fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
        let dtype = if self.device.is_cuda() {
            DType::BF16
        } else {
            DType::F32
        };

        let dummy_prompt: &str = "Describe the image";

        let input = self.tokenize_batch(vec![prompt])?;
        let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;

        let pages = self.get_images_from_pdf()?;
        let mut all_scores = Vec::new();
        for batch in pages.chunks(self.batch_size) {
            let page_images = self
                .images_to_tensor(batch, self.config.vision_config.image_size)?
                .to_device(&self.device)?
                .to_dtype(dtype)?;
            // Tile the dummy prompt across the image batch (dim 0).
            let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;

            let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
            let text_embeddings = self.model.forward_text(&input)?;

            // MaxSim: best patch per text token, summed over text tokens.
            let scores = text_embeddings
                .unsqueeze(1)?
                .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
                .max(3)?
                .sum(2)?;
            let batch_scores: Vec<f32> = scores
                .to_dtype(DType::F32)?
                .to_vec2()?
                .into_iter()
                .flatten()
                .collect();
            all_scores.extend(batch_scores);
        }

        // Sort page indices by descending score and keep the top_k.
        let mut indices: Vec<usize> = (0..all_scores.len()).collect();
        indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());

        let top_k_indices = indices[0..self.top_k].to_vec();

        Ok(top_k_indices)
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// Number of top pages to show.
    #[arg(long, default_value_t = 3)]
    top_k: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    #[arg(long)]
    pdf: String,

    #[arg(long)]
    start: Option<u32>,

    #[arg(long)]
    end: Option<u32>,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );

    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "vidore/colpali-v1.2-merged".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));

    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .repo(Repo::with_revision(
                "vidore/colpali".to_string(),
                RepoType::Model,
                "main".to_string(),
            ))
            .get("tokenizer.json")?,
    };

    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };

    let start = std::time::Instant::now();

    let config: paligemma::Config = paligemma::Config::paligemma_3b_448();

    println!("retrieved the files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = colpali::Model::new(&config, vb)?;

    let pdf = PDF::from_file(args.pdf)?;

    // Use the page range from the CLI if both --start and --end were given.
    let range = if let (Some(start), Some(end)) = (args.start, args.end) {
        pdf2image::Pages::Range(start..=end)
    } else {
        // pdf2image::Pages::All would also work, but a bug in the library
        // causes the first page to be rendered twice.
        pdf2image::Pages::Range(1..=pdf.page_count())
    };

    let mut retriever =
        PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
    let top_k_indices = retriever.retrieve(&args.prompt)?;

    println!("Prompt: {}", args.prompt);
    println!(
        "top {} pages most similar to the prompt",
        retriever.top_k
    );
    println!("-----------------------------------");
    for index in top_k_indices {
        println!("Page: {:?}", index + 1);
    }
    println!("-----------------------------------");
    Ok(())
}
42 changes: 42 additions & 0 deletions candle-transformers/src/models/colpali.rs
@@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

use super::paligemma;
use candle_nn::{linear, Linear};

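/// PaliGemma backbone with ColPali's extra linear head, which projects both
/// text and image hidden states into a shared 128-dimensional space.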
pub struct Model {
    pub model: paligemma::Model,
    pub custom_text_projection: Linear,
}

impl Model {
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let custom_text_projection = linear(
            config.text_config.hidden_size,
            128,
            vb.pp("custom_text_proj"),
        )?;

        Ok(Self {
            model,
            custom_text_projection,
        })
    }

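    // Vision path: run the images (with a placeholder text prompt) through
    // PaliGemma without the LM head, project to 128 dims, then L2-normalize
    // along the embedding dimension.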
    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let outputs = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        let outputs = self.custom_text_projection.forward(&outputs)?;
        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
        Ok(outputs)
    }

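    // Text path: same projection and L2 normalization, applied to the
    // prompt's token embeddings.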
    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let outputs = self.model.forward_without_projection(input_ids)?;
        let outputs = self.custom_text_projection.forward(&outputs)?;
        let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
        Ok(outputs)
    }
}
16 changes: 15 additions & 1 deletion candle-transformers/src/models/gemma.rs
@@ -403,7 +403,6 @@ impl Model {
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn forward_embeds(
        &mut self,
        xs: &Tensor,
@@ -420,6 +419,21 @@
            .apply(&self.lm_head)
    }

    // Forward the model and return the hidden states without applying the lm_head.
    pub fn forward_embeds_without_projection(
        &mut self,
        xs: &Tensor,
        attn_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Assert a rank-3 (batch, seq_len, hidden) input.
        let (_, _, _) = xs.dims3()?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
        }
        Ok(xs)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -7,6 +7,7 @@ pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;