From 831666f0a3a97d4eda28fd5d6fd8c5441f7a85ec Mon Sep 17 00:00:00 2001 From: Manuel Faysse <43467008+ManuelFay@users.noreply.github.com> Date: Tue, 29 Oct 2024 17:07:51 +0100 Subject: [PATCH] Fix dtype (#118) * fix dtype * add changelog --- CHANGELOG.md | 3 ++- colpali_engine/models/paligemma/colpali/modeling_colpali.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c756132a..bd436f1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Modified ColQwen and BiQwen to prevent the useless forward pass in the last layer of the original model (classification head) - Bumped "breaking" dependencies on MTEB and Transformers version and made the corresponding changes in the code -- Added a "num_image_tokens" kwarg to the `ColQwenProcessor` to allow for different image resolutions +- Casted Image dtype in ColPali due to breaking 4.46 transformers update +- Added a "num_image_tokens" kwarg to the `ColQwen2Processor` to allow for different image resolutions ### Fixed diff --git a/colpali_engine/models/paligemma/colpali/modeling_colpali.py b/colpali_engine/models/paligemma/colpali/modeling_colpali.py index 7292598b..db3c6ca0 100644 --- a/colpali_engine/models/paligemma/colpali/modeling_colpali.py +++ b/colpali_engine/models/paligemma/colpali/modeling_colpali.py @@ -34,6 +34,8 @@ def __init__(self, config: PaliGemmaConfig): def forward(self, *args, **kwargs) -> torch.Tensor: # Delete output_hidden_states from kwargs kwargs.pop("output_hidden_states", None) + if "pixel_values" in kwargs: + kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype) outputs = self.model(*args, output_hidden_states=True, **kwargs) # (batch_size, sequence_length, hidden_size) last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)