Skip to content

Commit

Permalink
Add ColPali (huggingface#2524)
Browse files Browse the repository at this point in the history
* add colpali

* cleanup

* fix clippy
  • Loading branch information
akshayballal95 authored Oct 1, 2024
1 parent 6110ad8 commit 888d886
Show file tree
Hide file tree
Showing 7 changed files with 394 additions and 1 deletion.
5 changes: 5 additions & 0 deletions candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}

[dev-dependencies]
anyhow = { workspace = true }
Expand Down Expand Up @@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]

[[example]]
name = "colpali"
required-features = ["pdf2image"]
18 changes: 18 additions & 0 deletions candle-examples/examples/colpali/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Colpali

[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)

```
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```

```
Prompt: What is Positional Encoding
top 3 page numbers that contain similarity to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```
268 changes: 268 additions & 0 deletions candle-examples/examples/colpali/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;

/// Bundles everything needed to rank the pages of one PDF against a text prompt.
struct PageRetriever {
    /// ColPali model used to embed both page images and the prompt.
    model: Model,
    /// PaliGemma configuration; `vision_config.image_size` sets the page image resolution.
    config: paligemma::Config,
    /// The opened PDF document whose pages are rendered and scored.
    pdf: PDF,
    /// Device on which token tensors are created and page images are moved before inference.
    device: Device,
    /// Tokenizer used to turn prompts into token ids.
    tokenizer: Tokenizer,
    /// Pages of `pdf` that are rendered and scored.
    range: pdf2image::Pages,
    /// Number of page images embedded per forward pass.
    batch_size: usize,
    /// Number of best-scoring pages returned by `retrieve`.
    top_k: usize,
}

impl PageRetriever {
/// Builds a retriever over `pdf`.
///
/// When `range` is `None` every page from 1 to `pdf.page_count()` is used.
/// The device handle is cloned so the caller keeps ownership of its own.
fn new(
    model: Model,
    config: paligemma::Config,
    pdf: PDF,
    tokenizer: Tokenizer,
    device: &Device,
    range: Option<pdf2image::Pages>,
    batch_size: usize,
    top_k: usize,
) -> Self {
    // Read the page count up front, before `pdf` is moved into the struct.
    let total_pages = pdf.page_count();
    let range = match range {
        Some(pages) => pages,
        None => pdf2image::Pages::Range(1..=total_pages),
    };
    let device = device.clone();
    Self {
        model,
        config,
        pdf,
        device,
        tokenizer,
        range,
        batch_size,
        top_k,
    }
}

/// Renders the configured page range of the PDF with default render options
/// and returns one image per rendered page.
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
    let options = RenderOptionsBuilder::default().build()?;
    let rendered = self.pdf.render(self.range.clone(), options)?;
    Ok(rendered)
}

/// Tokenizes `prompts` (with special tokens) and stacks the token ids into a
/// single tensor on `self.device`, one row per prompt.
///
/// No padding is applied here, so stacking fails if the prompts encode to
/// different lengths.
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
    let encodings = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
    let mut rows = Vec::with_capacity(encodings.len());
    for encoding in &encodings {
        rows.push(Tensor::new(encoding.get_ids(), &self.device)?);
    }
    let stacked = Tensor::stack(&rows, 0)?;
    Ok(stacked)
}

fn images_to_tensor(
&self,
pages: &[DynamicImage],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for page in pages.iter() {
let img = page.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
images.push(img);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}

/// Scores every rendered page against `prompt` and returns the indices of the
/// best-matching pages, best first.
///
/// The returned indices are 0-based positions within the rendered page range
/// (not absolute PDF page numbers). At most `self.top_k` indices are returned;
/// fewer when the range contains fewer pages.
///
/// # Errors
/// Propagates tokenization, PDF rendering and tensor/model errors.
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
    let dtype = if self.device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };

    // Text that accompanies every page image when it is embedded.
    let dummy_prompt: &str = "Describe the image";

    let input = self.tokenize_batch(vec![prompt])?;
    let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;

    let pages = self.get_images_from_pdf()?;
    let mut all_scores: Vec<f32> = Vec::with_capacity(pages.len());
    // `chunks` panics on a chunk size of 0, so fall back to one page per batch.
    for batch in pages.chunks(self.batch_size.max(1)) {
        let page_images = self
            .images_to_tensor(batch, self.config.vision_config.image_size)?
            .to_device(&self.device)?
            .to_dtype(dtype)?;
        // One copy of the dummy prompt per image in the batch; the token
        // dimension is left as is (repeat count 1).
        let dummy_input = dummy_input.repeat((page_images.dims()[0], 1))?;

        let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
        // NOTE(review): the text embedding does not depend on the batch and could
        // probably be computed once before the loop — confirm that `forward_text`
        // and `forward_images` do not rely on this call order before hoisting.
        let text_embeddings = self.model.forward_text(&input)?;

        // For every (prompt token, page) pair take the best-matching image token
        // (`max` over the last axis), then add those maxima over the prompt
        // tokens (`sum`) to get one score per page.
        let scores = text_embeddings
            .unsqueeze(1)?
            .broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
            .max(3)?
            .sum(2)?;
        let batch_scores: Vec<f32> = scores
            .to_dtype(DType::F32)?
            .to_vec2()?
            .into_iter()
            .flatten()
            .collect();
        all_scores.extend(batch_scores);
    }

    // Sort page indices by descending score. NaN scores are ranked last
    // instead of panicking the comparison.
    let sort_key = |idx: usize| -> f32 {
        let score = all_scores[idx];
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    };
    let mut indices: Vec<usize> = (0..all_scores.len()).collect();
    indices.sort_by(|&a, &b| sort_key(b).total_cmp(&sort_key(a)));

    // Never slice past the number of pages that were actually scored.
    indices.truncate(self.top_k);

    Ok(indices)
}
}

/// Command-line arguments of the ColPali page-retrieval example.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Text query that the PDF pages are ranked against.
    #[arg(long)]
    prompt: String,

    /// Number of top pages to show.
    #[arg(long, default_value_t = 3)]
    top_k: usize,

    /// Hugging Face Hub model repository (defaults to vidore/colpali-v1.2-merged).
    #[arg(long)]
    model_id: Option<String>,

    /// Revision of the model repository to use.
    #[arg(long, default_value = "main")]
    revision: String,

    /// Path to a local tokenizer.json; fetched from the Hub when omitted.
    #[arg(long)]
    tokenizer_file: Option<String>,

    /// Comma-separated list of local safetensors weight files; fetched from the Hub when omitted.
    #[arg(long)]
    weight_files: Option<String>,

    /// Path to the PDF file to search.
    #[arg(long)]
    pdf: String,

    /// First page of the page range to search (1-based, inclusive).
    #[arg(long)]
    start: Option<u32>,

    /// Last page of the page range to search (inclusive).
    #[arg(long)]
    end: Option<u32>,
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);

let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "vidore/colpali-v1.2-merged".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));

let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.repo(Repo::with_revision(
"vidore/colpali".to_string(),
RepoType::Model,
"main".to_string(),
))
.get("tokenizer.json")?,
};

let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};

let start = std::time::Instant::now();

let config: paligemma::Config = paligemma::Config::paligemma_3b_448();

println!("retrieved the files in {:?}", start.elapsed());

let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(false)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = colpali::Model::new(&config, vb)?;

let pdf = PDF::from_file(args.pdf)?;

// check if start and end given in arg
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
pdf2image::Pages::Range(start..=end)
} else {
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
};

let mut retriever =
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
let top_k_indices = retriever.retrieve(&args.prompt)?;

println!("Prompt: {}", args.prompt);
println!(
"top {} page numbers that contain similarity to the prompt",
retriever.top_k
);
println!("-----------------------------------");
for index in top_k_indices {
println!("Page: {:?}", index + 1);
}
println!("-----------------------------------");
Ok(())
}
42 changes: 42 additions & 0 deletions candle-transformers/src/models/colpali.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

use super::paligemma;
use candle_nn::{linear, Linear};

/// ColPali model: a PaliGemma backbone followed by a linear projection that is
/// applied to the backbone's hidden states.
pub struct Model {
    /// Underlying PaliGemma model producing the hidden states.
    pub model: paligemma::Model,
    /// Linear layer mapping hidden states to the embedding space (128 outputs, see `Model::new`).
    pub custom_text_projection: Linear,
}

impl Model {
    /// Loads the PaliGemma backbone from the `model` prefix and the projection
    /// layer from the `custom_text_proj` prefix of `vb`.
    pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
        // Output width of the projection layer.
        const EMBEDDING_DIM: usize = 128;

        let model = paligemma::Model::new(config, vb.pp("model"))?;
        let in_dim = config.text_config.hidden_size;
        let custom_text_projection = linear(in_dim, EMBEDDING_DIM, vb.pp("custom_text_proj"))?;

        Ok(Self {
            model,
            custom_text_projection,
        })
    }

    /// Embeds images (together with their accompanying `input_ids`) and returns
    /// the projected embeddings, L2-normalised along dimension 2.
    pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self
            .model
            .setup_without_projection(pixel_values, input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Embeds text tokens and returns the projected embeddings, L2-normalised
    /// along dimension 2.
    pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let hidden = self.model.forward_without_projection(input_ids)?;
        self.project_and_normalize(&hidden)
    }

    /// Applies the projection layer and divides each vector by its L2 norm
    /// (computed over dimension 2).
    fn project_and_normalize(&self, hidden: &Tensor) -> Result<Tensor> {
        let projected = self.custom_text_projection.forward(hidden)?;
        let norm = projected.sqr()?.sum_keepdim(2)?.sqrt()?;
        projected.broadcast_div(&norm)
    }
}
16 changes: 15 additions & 1 deletion candle-transformers/src/models/gemma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}

pub fn forward_embeds(
&mut self,
xs: &Tensor,
Expand All @@ -420,6 +419,21 @@ impl Model {
.apply(&self.lm_head)
}

/// Runs already-embedded inputs through every decoder layer and returns the
/// resulting hidden states, without applying `lm_head`.
///
/// `xs` must be a rank-3 tensor; it is scaled by `sqrt(hidden_size)` before
/// entering the first layer.
pub fn forward_embeds_without_projection(
    &mut self,
    xs: &Tensor,
    attn_mask: Option<&Tensor>,
    seqlen_offset: usize,
) -> Result<Tensor> {
    // Only used as a rank check: errors out unless `xs` has exactly three dims.
    xs.dims3()?;
    let scale = (self.hidden_size as f64).sqrt();
    let scaled = (xs * scale)?;
    self.layers.iter_mut().try_fold(scaled, |hidden, layer| {
        layer.forward(&hidden, attn_mask, seqlen_offset)
    })
}

pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;
Expand Down
Loading

0 comments on commit 888d886

Please sign in to comment.