Skip to content

Commit c72ba69

Browse files
tonywu71yonigozlan
andauthored
Add ColQwen2 to 🤗 transformers (#35778)
* feat: add colqwen2 (wip) * tests: fix test_attention_outputs * tests: reduce hidden size to accelerate tests * tests: fix `test_attention_outputs` 🥳 * fix: fix wrong parent class for `ColQwen2ForRetrievalOutput` * fix: minor typing and style changes * chore: run `make style` * feat: remove redundant `max_num_visual_tokens` attribute in `ColQwen2Processor` * tests: tweak comments * style: apply ruff formatter * feat: move default values for `visual_prompt_prefix` and `query_prefix` * docs: update ColQwen2 model card * docs: tweak model cards * docs: add required example config checkpoint * tests: update expected scores in integration test * docs: tweak quickstart snippets * fix: address PR comments * tests: fix colqwen2 tests + tweak comment in colpali test * tests: unskip useful tests * fix: fix bug when `visual_prompt_prefix` or `query_prefix` is an empty string * fix: fix ColPali outputs when `return_dict == False` * fix: fix issue with PaliGemma output not being a dict * docs: set default dtype to bfloat16 in quickstart snippets * fix: fix error when `return_dict=False` in ColPali and ColQwen2 * tests: fix special tokens not being replaced in input_ids * style: fix lint * fix: `ColQwen2Processor`'s `padding_side` is now set from `processor_config.json` * fix: remove unused `padding_side` in ColQwen2 model * docs: update ColQwen2's model doc * fix: fix harcoded vlm backbone class in ColQwen2Config * fix: remove `padding_side` from ColQwen2Processor as should fed from kwargs * docs: fix typo in model docstring * docs: add illuin mention in model docs * fix: let `padding_size` be handled by `tokenizer_config.json` * docs: add colpali reference url in colqwen2's model doc * docs: add Hf mention in model docs * docs: add late interaction mention in model docs * docs: tweak colqwen2 model doc * docs: update reference checkpoint for ColPali to v1.3 * docs: simplify quickstart snippets * docs: remove redundant `.eval()` * refactor: use `can_return_tuple` decorator for ColPali and ColQwen2 * docs: fix copyright date * docs: add missing copyright in tests * fix: raise error when `initializer_range` is not in config * docs: remove redundant `.eval()` in colpali doc * fix: fix `get_text_config` now that Qwen2VL has a proper `text_config` attribute See #37268 for details about changes in Qwen2VL's config. * fix: add missing `initializer_range` attribute in `ColQwen2Config` * fix: use `get_text_config` in `resize_token_embeddings` * update colwen2 with auto_docstring * docs: fix wrong copyright year * chore: remove `raise` as `initializer_range` has a default value in `ColQwen2Config` * refactor: merge `inner_forward` into `forward` * Refactor colqwen2 after refactoring of qwen2VL, use modular for modeling code * protect torch import in modular to protect in processing * protect torch import in modular to protect in processing * tests: fix hf model path in ColQwen2 integration test * docs: clarify `attn_implementation` and add comments * docs: add fallback snippet for using offline PIL dummy images * docs: temporarily revert attn_implementation to `None` while sdpa is not fixed * docs: tweaks in colpali/colqwen2 quick start snippets * fix: add missing flags to enable SDPA/Flex Attention in ColQwen2 model * fix: add missing changes in modular file * fix modeling tests --------- Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
1 parent beaed8c commit c72ba69

23 files changed

+2288
-94
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,8 @@
937937
title: CLVP
938938
- local: model_doc/colpali
939939
title: ColPali
940+
- local: model_doc/colqwen2
941+
title: ColQwen2
940942
- local: model_doc/data2vec
941943
title: Data2Vec
942944
- local: model_doc/deplot

docs/source/en/model_doc/colpali.md

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,37 @@ rendered properly in your Markdown viewer.
2020

2121
# ColPali
2222

23-
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed embeddings. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
23+
[ColPali](https://huggingface.co/papers/2407.01449) is a model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColPali treats each page as an image. It uses [Paligemma-3B](./paligemma) to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
2424

25-
You can find all the original ColPali checkpoints under the [ColPali](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
25+
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).
26+
27+
You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
2628

2729
> [!TIP]
2830
> Click on the ColPali models in the right sidebar for more examples of how to use ColPali for image retrieval.
2931
3032
<hfoptions id="usage">
3133
<hfoption id="image retrieval">
3234

33-
```py
35+
```python
3436
import requests
3537
import torch
3638
from PIL import Image
39+
3740
from transformers import ColPaliForRetrieval, ColPaliProcessor
3841

39-
# Load model (bfloat16 support is limited; fallback to float32 if needed)
42+
43+
# Load the model and the processor
44+
model_name = "vidore/colpali-v1.3-hf"
45+
4046
model = ColPaliForRetrieval.from_pretrained(
41-
"vidore/colpali-v1.2-hf",
42-
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
47+
model_name,
48+
torch_dtype=torch.bfloat16,
4349
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
44-
).eval()
45-
50+
)
4651
processor = ColPaliProcessor.from_pretrained(model_name)
4752

53+
# The document page screenshots from your corpus
4854
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
4955
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
5056

@@ -53,38 +59,53 @@ images = [
5359
Image.open(requests.get(url2, stream=True).raw),
5460
]
5561

62+
# The queries you want to retrieve documents for
5663
queries = [
57-
"Who printed the edition of Romeo and Juliet?",
5864
"When was the United States Declaration of Independence proclaimed?",
65+
"Who printed the edition of Romeo and Juliet?",
5966
]
6067

6168
# Process the inputs
62-
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
63-
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
69+
inputs_images = processor(images=images).to(model.device)
70+
inputs_text = processor(text=queries).to(model.device)
6471

6572
# Forward pass
6673
with torch.no_grad():
6774
image_embeddings = model(**inputs_images).embeddings
6875
query_embeddings = model(**inputs_text).embeddings
6976

77+
# Score the queries against the images
7078
scores = processor.score_retrieval(query_embeddings, image_embeddings)
7179

7280
print("Retrieval scores (query x image):")
7381
print(scores)
7482
```
83+
84+
If you have issue with loading the images with PIL, you can use the following code to create dummy images:
85+
86+
```python
87+
images = [
88+
Image.new("RGB", (128, 128), color="white"),
89+
Image.new("RGB", (64, 32), color="black"),
90+
]
91+
```
92+
7593
</hfoption>
7694
</hfoptions>
7795

7896
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
7997

8098
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
8199

82-
```py
100+
```python
83101
import requests
84102
import torch
85103
from PIL import Image
86-
from transformers import ColPaliForRetrieval, ColPaliProcessor
87-
from transformers import BitsAndBytesConfig
104+
105+
from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor
106+
107+
108+
model_name = "vidore/colpali-v1.3-hf"
88109

89110
# 4-bit quantization configuration
90111
bnb_config = BitsAndBytesConfig(
@@ -94,14 +115,11 @@ bnb_config = BitsAndBytesConfig(
94115
bnb_4bit_compute_dtype=torch.float16,
95116
)
96117

97-
model_name = "vidore/colpali-v1.2-hf"
98-
99-
# Load model
100118
model = ColPaliForRetrieval.from_pretrained(
101119
model_name,
102120
quantization_config=bnb_config,
103-
device_map="cuda"
104-
).eval()
121+
device_map="cuda",
122+
)
105123

106124
processor = ColPaliProcessor.from_pretrained(model_name)
107125

@@ -114,8 +132,8 @@ images = [
114132
]
115133

116134
queries = [
117-
"Who printed the edition of Romeo and Juliet?",
118135
"When was the United States Declaration of Independence proclaimed?",
136+
"Who printed the edition of Romeo and Juliet?",
119137
]
120138

121139
# Process the inputs
@@ -127,6 +145,7 @@ with torch.no_grad():
127145
image_embeddings = model(**inputs_images).embeddings
128146
query_embeddings = model(**inputs_text).embeddings
129147

148+
# Score the queries against the images
130149
scores = processor.score_retrieval(query_embeddings, image_embeddings)
131150

132151
print("Retrieval scores (query x image):")
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
17+
<div style="float: right;">
18+
<div class="flex flex-wrap space-x-1">
19+
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
20+
</div>
21+
</div>
22+
23+
# ColQwen2
24+
25+
[ColQwen2](https://doi.org/10.48550/arXiv.2407.01449) is a variant of the [ColPali](./colpali) model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2 treats each page as an image. It uses the [Qwen2-VL](./qwen2_vl) backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.
26+
27+
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology) and [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace).
28+
29+
You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.
30+
31+
> [!TIP]
32+
> Click on the ColQwen2 models in the right sidebar for more examples of how to use ColQwen2 for image retrieval.
33+
34+
<hfoptions id="usage">
35+
<hfoption id="image retrieval">
36+
37+
```python
38+
import requests
39+
import torch
40+
from PIL import Image
41+
42+
from transformers import ColQwen2ForRetrieval, ColQwen2Processor
43+
from transformers.utils.import_utils import is_flash_attn_2_available
44+
45+
46+
# Load the model and the processor
47+
model_name = "vidore/colqwen2-v1.0-hf"
48+
49+
model = ColQwen2ForRetrieval.from_pretrained(
50+
model_name,
51+
torch_dtype=torch.bfloat16,
52+
device_map="auto", # "cpu", "cuda", or "mps" for Apple Silicon
53+
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa",
54+
)
55+
processor = ColQwen2Processor.from_pretrained(model_name)
56+
57+
# The document page screenshots from your corpus
58+
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
59+
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
60+
61+
images = [
62+
Image.open(requests.get(url1, stream=True).raw),
63+
Image.open(requests.get(url2, stream=True).raw),
64+
]
65+
66+
# The queries you want to retrieve documents for
67+
queries = [
68+
"When was the United States Declaration of Independence proclaimed?",
69+
"Who printed the edition of Romeo and Juliet?",
70+
]
71+
72+
# Process the inputs
73+
inputs_images = processor(images=images).to(model.device)
74+
inputs_text = processor(text=queries).to(model.device)
75+
76+
# Forward pass
77+
with torch.no_grad():
78+
image_embeddings = model(**inputs_images).embeddings
79+
query_embeddings = model(**inputs_text).embeddings
80+
81+
# Score the queries against the images
82+
scores = processor.score_retrieval(query_embeddings, image_embeddings)
83+
84+
print("Retrieval scores (query x image):")
85+
print(scores)
86+
```
87+
88+
If you have issue with loading the images with PIL, you can use the following code to create dummy images:
89+
90+
```python
91+
images = [
92+
Image.new("RGB", (128, 128), color="white"),
93+
Image.new("RGB", (64, 32), color="black"),
94+
]
95+
```
96+
97+
</hfoption>
98+
</hfoptions>
99+
100+
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
101+
102+
The example below uses [bitsandbytes](../quantization/bitsandbytes.md) to quantize the weights to int4.
103+
104+
```python
105+
import requests
106+
import torch
107+
from PIL import Image
108+
109+
from transformers import BitsAndBytesConfig, ColQwen2ForRetrieval, ColQwen2Processor
110+
111+
112+
model_name = "vidore/colqwen2-v1.0-hf"
113+
114+
# 4-bit quantization configuration
115+
bnb_config = BitsAndBytesConfig(
116+
load_in_4bit=True,
117+
bnb_4bit_use_double_quant=True,
118+
bnb_4bit_quant_type="nf4",
119+
bnb_4bit_compute_dtype=torch.float16,
120+
)
121+
122+
model = ColQwen2ForRetrieval.from_pretrained(
123+
model_name,
124+
quantization_config=bnb_config,
125+
device_map="cuda",
126+
).eval()
127+
128+
processor = ColQwen2Processor.from_pretrained(model_name)
129+
130+
url1 = "https://upload.wikimedia.org/wikipedia/commons/8/89/US-original-Declaration-1776.jpg"
131+
url2 = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4c/Romeoandjuliet1597.jpg/500px-Romeoandjuliet1597.jpg"
132+
133+
images = [
134+
Image.open(requests.get(url1, stream=True).raw),
135+
Image.open(requests.get(url2, stream=True).raw),
136+
]
137+
138+
queries = [
139+
"When was the United States Declaration of Independence proclaimed?",
140+
"Who printed the edition of Romeo and Juliet?",
141+
]
142+
143+
# Process the inputs
144+
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
145+
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)
146+
147+
# Forward pass
148+
with torch.no_grad():
149+
image_embeddings = model(**inputs_images).embeddings
150+
query_embeddings = model(**inputs_text).embeddings
151+
152+
# Score the queries against the images
153+
scores = processor.score_retrieval(query_embeddings, image_embeddings)
154+
155+
print("Retrieval scores (query x image):")
156+
print(scores)
157+
```
158+
159+
## Notes
160+
161+
- [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
162+
- Unlike ColPali, ColQwen2 supports arbitrary image resolutions and aspect ratios, which means images are not resized into fixed-size squares. This preserves more of the original input signal.
163+
- Larger input images generate longer multi-vector embeddings, allowing users to adjust image resolution to balance performance and memory usage.
164+
165+
## ColQwen2Config
166+
167+
[[autodoc]] ColQwen2Config
168+
169+
## ColQwen2Processor
170+
171+
[[autodoc]] ColQwen2Processor
172+
173+
## ColQwen2ForRetrieval
174+
175+
[[autodoc]] ColQwen2ForRetrieval
176+
- forward

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
from .cohere import *
6363
from .cohere2 import *
6464
from .colpali import *
65+
from .colqwen2 import *
6566
from .conditional_detr import *
6667
from .convbert import *
6768
from .convnext import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
("cohere", "CohereConfig"),
8080
("cohere2", "Cohere2Config"),
8181
("colpali", "ColPaliConfig"),
82+
("colqwen2", "ColQwen2Config"),
8283
("conditional_detr", "ConditionalDetrConfig"),
8384
("convbert", "ConvBertConfig"),
8485
("convnext", "ConvNextConfig"),
@@ -437,6 +438,7 @@
437438
("cohere", "Cohere"),
438439
("cohere2", "Cohere2"),
439440
("colpali", "ColPali"),
441+
("colqwen2", "ColQwen2"),
440442
("conditional_detr", "Conditional DETR"),
441443
("convbert", "ConvBERT"),
442444
("convnext", "ConvNeXT"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@
365365
("bloom", "BloomForCausalLM"),
366366
("camembert", "CamembertForMaskedLM"),
367367
("colpali", "ColPaliForRetrieval"),
368+
("colqwen2", "ColQwen2ForRetrieval"),
368369
("ctrl", "CTRLLMHeadModel"),
369370
("data2vec-text", "Data2VecTextForMaskedLM"),
370371
("deberta", "DebertaForMaskedLM"),

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
("clipseg", "CLIPSegProcessor"),
6767
("clvp", "ClvpProcessor"),
6868
("colpali", "ColPaliProcessor"),
69+
("colqwen2", "ColQwen2Processor"),
6970
("emu3", "Emu3Processor"),
7071
("flava", "FlavaProcessor"),
7172
("fuyu", "FuyuProcessor"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@
147147
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
148148
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
149149
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
150+
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
150151
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
151152
(
152153
"cpm",

src/transformers/models/colpali/configuration_colpali.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ class ColPaliConfig(PretrainedConfig):
3333
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
3434
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
3535
36-
The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
37-
3836
Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
3937
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
4038
@@ -93,7 +91,7 @@ def __init__(
9391
)
9492

9593
self.vlm_config = vlm_config
96-
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
94+
self.text_config = text_config if text_config is not None else vlm_config.text_config
9795
if isinstance(self.text_config, dict):
9896
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
9997
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)

0 commit comments

Comments
 (0)