diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 2c96f87d68..4edde7a966 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -36,6 +36,7 @@ serde_json = { workspace = true } symphonia = { version = "0.5.3", features = ["all"], optional = true } tokenizers = { workspace = true, features = ["onig"] } cpal = { version = "0.15.2", optional = true } +pdf2image = { version = "0.1.2" , optional = true} [dev-dependencies] anyhow = { workspace = true } @@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"] [[example]] name = "silero-vad" required-features = ["onnx"] + +[[example]] +name = "colpali" +required-features = ["pdf2image"] \ No newline at end of file diff --git a/candle-examples/examples/colpali/README.md b/candle-examples/examples/colpali/README.md new file mode 100644 index 0000000000..e6a5579801 --- /dev/null +++ b/candle-examples/examples/colpali/README.md @@ -0,0 +1,18 @@ +# ColPali + +[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged) + +``` +wget https://arxiv.org/pdf/1706.03762.pdf +cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf" +``` + +``` +Prompt: What is Positional Encoding 
+top 3 page numbers that contain similarity to the prompt +----------------------------------- +Page: 6 +Page: 11 +Page: 15 +----------------------------------- +``` \ No newline at end of file diff --git a/candle-examples/examples/colpali/main.rs b/candle-examples/examples/colpali/main.rs new file mode 100644 index 0000000000..2a1cc96b9e --- /dev/null +++ b/candle-examples/examples/colpali/main.rs @@ -0,0 +1,268 @@ +use anyhow::{Error as E, Result}; +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::colpali::Model; +use candle_transformers::models::{colpali, paligemma}; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use image::DynamicImage; +use pdf2image::{RenderOptionsBuilder, PDF}; +use tokenizers::Tokenizer; + +struct PageRetriever { + model: Model, + config: paligemma::Config, + pdf: PDF, + device: Device, + tokenizer: Tokenizer, + range: pdf2image::Pages, + batch_size: usize, + top_k: usize, +} + +impl PageRetriever { + fn new( + model: Model, + config: paligemma::Config, + pdf: PDF, + tokenizer: Tokenizer, + device: &Device, + range: Option, + batch_size: usize, + top_k: usize, + ) -> Self { + let page_count = pdf.page_count(); + Self { + model, + config, + pdf, + device: device.clone(), + tokenizer, + range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)), + batch_size, + top_k, + } + } + + fn get_images_from_pdf(&self) -> Result> { + let pages = self + .pdf + .render(self.range.clone(), RenderOptionsBuilder::default().build()?)?; + Ok(pages) + } + + fn tokenize_batch(&self, prompts: Vec<&str>) -> Result { + let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), &self.device) + }) + .collect::>>()?; + let input = Tensor::stack(&token_ids, 0)?; + Ok(input) + } + + fn images_to_tensor( + &self, + pages: &[DynamicImage], + 
image_size: usize, + ) -> anyhow::Result { + let mut images = vec![]; + for page in pages.iter() { + let img = page.resize_to_fill( + image_size as u32, + image_size as u32, + image::imageops::FilterType::Triangle, + ); + let img = img.to_rgb8(); + let img = img.into_raw(); + let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)? + .permute((2, 0, 1))? + .to_dtype(DType::F32)? + .affine(2. / 255., -1.)?; + images.push(img); + } + let images = Tensor::stack(&images, 0)?; + Ok(images) + } + + fn retrieve(&mut self, prompt: &str) -> Result> { + let dtype = if self.device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + + let dummy_prompt: &str = "Describe the image"; + + let input = self.tokenize_batch(vec![prompt])?; + let dummy_input = self.tokenize_batch(vec![dummy_prompt])?; + + let pages = self.get_images_from_pdf()?; + let mut all_scores = Vec::new(); + for batch in pages.chunks(self.batch_size) { + let page_images = self + .images_to_tensor(batch, self.config.vision_config.image_size)? + .to_device(&self.device)? + .to_dtype(dtype)?; + let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?; + + let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?; + let text_embeddings = self.model.forward_text(&input)?; + + let scores = text_embeddings + .unsqueeze(1)? + .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)? + .max(3)? + .sum(2)?; + let batch_scores: Vec = scores + .to_dtype(DType::F32)? + .to_vec2()? + .into_iter() + .flatten() + .collect(); + all_scores.extend(batch_scores); + } + + let mut indices: Vec = (0..all_scores.len()).collect(); + indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap()); + + let top_k_indices = indices[0..self.top_k].to_vec(); + + Ok(top_k_indices) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. 
+ #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + prompt: String, + + /// number of top pages to show. + #[arg(long, default_value_t = 3)] + top_k: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + weight_files: Option, + + #[arg(long)] + pdf: String, + + #[arg(long)] + start: Option, + + #[arg(long)] + end: Option, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "vidore/colpali-v1.2-merged".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => api + .repo(Repo::with_revision( + "vidore/colpali".to_string(), + RepoType::Model, + "main".to_string(), + )) + .get("tokenizer.json")?, + }; + + let filenames = match args.weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let start = std::time::Instant::now(); + + let config: paligemma::Config = paligemma::Config::paligemma_3b_448(); + + println!("retrieved the files in {:?}", start.elapsed()); + + let tokenizer = 
Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let device = candle_examples::device(false)?; + let dtype = if device.is_cuda() { + DType::BF16 + } else { + DType::F32 + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = colpali::Model::new(&config, vb)?; + + let pdf = PDF::from_file(args.pdf)?; + + // check if start and end given in arg + let range = if let (Some(start), Some(end)) = (args.start, args.end) { + pdf2image::Pages::Range(start..=end) + } else { + pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice. + }; + + let mut retriever = + PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3); + let top_k_indices = retriever.retrieve(&args.prompt)?; + + println!("Prompt: {}", args.prompt); + println!( + "top {} page numbers that contain similarity to the prompt", + retriever.top_k + ); + println!("-----------------------------------"); + for index in top_k_indices { + println!("Page: {:?}", index + 1); + } + println!("-----------------------------------"); + Ok(()) +} diff --git a/candle-transformers/src/models/colpali.rs b/candle-transformers/src/models/colpali.rs new file mode 100644 index 0000000000..1299b0a410 --- /dev/null +++ b/candle-transformers/src/models/colpali.rs @@ -0,0 +1,42 @@ +use candle::{Module, Result, Tensor}; +use candle_nn::VarBuilder; + +use super::paligemma; +use candle_nn::{linear, Linear}; + +pub struct Model { + pub model: paligemma::Model, + pub custom_text_projection: Linear, +} + +impl Model { + pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result { + let model = paligemma::Model::new(config, vb.pp("model"))?; + let custom_text_projection = linear( + config.text_config.hidden_size, + 128, + vb.pp("custom_text_proj"), + )?; + + Ok(Self { + model, + custom_text_projection, + }) + } + + pub fn forward_images(&mut self, 
pixel_values: &Tensor, input_ids: &Tensor) -> Result { + let outputs = self + .model + .setup_without_projection(pixel_values, input_ids)?; + let outputs = self.custom_text_projection.forward(&outputs)?; + let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; + Ok(outputs) + } + + pub fn forward_text(&mut self, input_ids: &Tensor) -> Result { + let outputs = self.model.forward_without_projection(input_ids)?; + let outputs = self.custom_text_projection.forward(&outputs)?; + let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?; + Ok(outputs) + } +} diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 69e2267876..c22a39480c 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -403,7 +403,6 @@ impl Model { .apply(&self.norm)? .apply(&self.lm_head) } - pub fn forward_embeds( &mut self, xs: &Tensor, @@ -420,6 +419,21 @@ impl Model { .apply(&self.lm_head) } + // Forward the model and return the hidden states without the lm_head + pub fn forward_embeds_without_projection( + &mut self, + xs: &Tensor, + attn_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (_, _, _) = xs.dims3()?; + let mut xs = (xs * (self.hidden_size as f64).sqrt())?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attn_mask, seqlen_offset)? 
+ } + Ok(xs) + } + pub fn clear_kv_cache(&mut self) { for layer in self.layers.iter_mut() { layer.clear_kv_cache() diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 09876503ed..80cd4f810c 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -7,6 +7,7 @@ pub mod blip_text; pub mod chatglm; pub mod clip; pub mod codegeex4_9b; +pub mod colpali; pub mod convmixer; pub mod convnext; pub mod dac; diff --git a/candle-transformers/src/models/paligemma.rs b/candle-transformers/src/models/paligemma.rs index e22ab241be..a5e7f694f5 100644 --- a/candle-transformers/src/models/paligemma.rs +++ b/candle-transformers/src/models/paligemma.rs @@ -33,6 +33,29 @@ impl Config { projection_dim: 2048, } } + + pub fn paligemma_3b_448() -> Self { + Self { + vision_config: siglip::VisionConfig::paligemma_3b_448(), + text_config: gemma::Config { + hidden_size: 2048, + intermediate_size: 16384, + num_attention_heads: 8, + num_hidden_layers: 18, + num_key_value_heads: 1, + // Default values. + rope_theta: 10000., + head_dim: 256, + hidden_act: Some(candle_nn::Activation::GeluPytorchTanh), + hidden_activation: None, + attention_bias: false, + max_position_embeddings: 8192, + rms_norm_eps: 1e-6, + vocab_size: 257216, + }, + projection_dim: 2048, + } + } } #[derive(Clone, Debug)] @@ -102,6 +125,28 @@ impl Model { self.language_model.forward(input_ids, pos) } + pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result { + self.clear_kv_cache(); + let input_embeds = self.language_model.embed_tokens().forward(input_ids)?; + self.language_model + .forward_embeds_without_projection(&input_embeds, None, 0) + } + pub fn setup_without_projection( + &mut self, + pixel_values: &Tensor, + input_ids: &Tensor, + ) -> Result { + self.clear_kv_cache(); + let image_features = self + .vision_tower + .forward(pixel_values)? 
+ .apply(&self.multi_modal_projector)?; + let image_features = crate::models::clip::div_l2_norm(&image_features)?; + let text_features = self.language_model.embed_tokens().forward(input_ids)?; + let input_embeds = Tensor::cat(&[image_features, text_features], 1)?; + self.language_model + .forward_embeds_without_projection(&input_embeds, None, 0) + } pub fn clear_kv_cache(&mut self) { self.pos = 0; self.language_model.clear_kv_cache()