Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,8 @@
title: CLVP
- local: model_doc/colpali
title: ColPali
- local: model_doc/colqwen2_5
title: ColQwen2.5
- local: model_doc/data2vec
title: Data2Vec
- local: model_doc/deplot
Expand Down
109 changes: 109 additions & 0 deletions docs/source/en/model_doc/colqwen2_5.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

# ColQwen2.5

[ColQwen2](https://doi.org/10.48550/arXiv.2407.01449) is a variant of the [ColPali](./colpali) model designed to retrieve documents by analyzing their visual features. Unlike traditional systems that rely heavily on text extraction and OCR, ColQwen2.5 treats each page as an image. It uses the [Qwen2.5-VL](./qwen2_5_vl) backbone to capture not only text, but also the layout, tables, charts, and other visual elements to create detailed multi-vector embeddings that can be used for retrieval by computing pairwise late interaction similarity scores. This offers a more comprehensive understanding of documents and enables more efficient and accurate retrieval.

This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) (ILLUIN Technology), [@yonigozlan](https://huggingface.co/yonigozlan) (HuggingFace) and [@qnguyen3](https://huggingface.co/qnguyen3) (WARA Media & Language).

You can find all the original ColPali checkpoints under Vidore's [Hf-native ColVision Models](https://huggingface.co/collections/vidore/hf-native-colvision-models-6755d68fc60a8553acaa96f7) collection.

> [!TIP]
> Click on the ColQwen2.5 models in the right sidebar for more examples of how to use ColQwen2.5 for image retrieval.

<hfoptions id="usage">
<hfoption id="image retrieval">

```python
import requests
import torch
from PIL import Image

from transformers import BitsAndBytesConfig, ColQwen2_5ForRetrieval, ColQwen2_5Processor


model_name = "qnguyen3/colqwen2_5-v0.2-hf"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)

model = ColQwen2_5ForRetrieval.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="cuda",
).eval()

processor = ColQwen2_5Processor.from_pretrained(model_name)

url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
url2 = "http://images.cocodataset.org/val2017/000000212573.jpg"

images = [
Image.open(requests.get(url1, stream=True).raw),
Image.open(requests.get(url2, stream=True).raw),
]

queries = [
"WHat are the colors of the two cats?",
"Who printed the edition of Romeo and Juliet?",
]

# Process the inputs
inputs_images = processor(images=images, return_tensors="pt").to(model.device)
inputs_text = processor(text=queries, return_tensors="pt").to(model.device)

# Forward pass
with torch.no_grad():
image_embeddings = model(**inputs_images).embeddings
query_embeddings = model(**inputs_text).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

print("Retrieval scores (query x image):")
print(scores)
```

## Notes

- [`~ColQwen2_5Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
- Unlike ColPali, ColQwen2.5 supports arbitrary image resolutions and aspect ratios, which means images are not resized into fixed-size squares. This preserves more of the original input signal.
- Larger input images generate longer multi-vector embeddings, allowing users to adjust image resolution to balance performance and memory usage.

## ColQwen2_5Config

[[autodoc]] ColQwen2_5Config

## ColQwen2_5Processor

[[autodoc]] ColQwen2_5Processor

## ColQwen2_5ForRetrieval

[[autodoc]] ColQwen2_5ForRetrieval
- forward
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from .cohere import *
from .cohere2 import *
from .colpali import *
from .colqwen2_5 import *
from .conditional_detr import *
from .convbert import *
from .convnext import *
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
("cohere", "CohereConfig"),
("cohere2", "Cohere2Config"),
("colpali", "ColPaliConfig"),
("colqwen2_5", "ColQwen2_5Config"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
Expand Down Expand Up @@ -432,6 +433,7 @@
("cohere", "Cohere"),
("cohere2", "Cohere2"),
("colpali", "ColPali"),
("colqwen2_5", "ColQwen2_5"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("colpali", "ColPaliForRetrieval"),
("colqwen2_5", "ColQwen2_5ForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2_5", "ColQwen2_5Processor"),
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("colqwen2_5", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/models/colqwen2_5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_colqwen2_5 import *
from .modeling_colqwen2_5 import *
from .processing_colqwen2_5 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
94 changes: 94 additions & 0 deletions src/transformers/models/colqwen2_5/configuration_colqwen2_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy
from typing import Any, Dict

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class ColQwen2_5Config(PretrainedConfig):
r"""
Configuration class to store the configuration of a [`ColQwen2_5ForRetrieval`]. It is used to instantiate an instance
of `ColQwen2_5ForRetrieval` according to the specified arguments, defining the model architecture following the methodology
from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.

Instantiating a configuration with the defaults will yield a similar configuration to the vision encoder used by the pre-trained
ColQwen2.5-v1.0 model, e.g. [vidore/colqwen2_5-v1.0-hf](https://huggingface.co/vidore/colqwen2_5-v1.0-hf).

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vlm_config (`PretrainedConfig`, *optional*):
Configuration of the VLM backbone model.
embedding_dim (`int`, *optional*, defaults to 128):
Dimension of the multi-vector embeddings produced by the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:

```python
from transformers.models.colqwen2_5 import ColQwen2_5Config, ColQwen2_5ForRetrieval

config = ColQwen2_5Config()
model = ColQwen2_5ForRetrieval(config)
```
"""

model_type = "colqwen2_5"
sub_configs: Dict[str, Any] = {"vlm_config": PretrainedConfig}

def __init__(
self,
vlm_config=None,
embedding_dim: int = 128,
initializer_range: float = 0.02,
**kwargs,
):
if vlm_config is None:
vlm_config = CONFIG_MAPPING["qwen2_5_vl"]()
logger.info(
"`vlm_config` is `None`. Initializing `vlm_config` with the `Qwen2_5VLConfig` with default values."
)
elif isinstance(vlm_config, dict):
vlm_config = deepcopy(vlm_config)
if "model_type" not in vlm_config:
raise KeyError(
"The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
)
vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
elif isinstance(vlm_config, PretrainedConfig):
vlm_config = vlm_config
else:
raise TypeError(
f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
)

self.vlm_config = vlm_config
self.embedding_dim = embedding_dim
self.initializer_range = initializer_range
super().__init__(**kwargs)

def get_text_config(self, decoder=False) -> PretrainedConfig:
return self.vlm_config.get_text_config(decoder=decoder)


__all__ = ["ColQwen2_5Config"]
Loading