Fix dtype (#118)
* fix dtype

* add changelog
ManuelFay authored Oct 29, 2024
1 parent 5986a5b commit 831666f
Showing 2 changed files with 4 additions and 1 deletion.
CHANGELOG.md (3 changes: 2 additions & 1 deletion)
@@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 
 - Modified ColQwen and BiQwen to prevent the useless forward pass in the last layer of the original model (classification head)
 - Bumped "breaking" dependencies on MTEB and Transformers version and made the corresponding changes in the code
-- Added a "num_image_tokens" kwarg to the `ColQwenProcessor` to allow for different image resolutions
+- Casted Image dtype in ColPali due to breaking 4.46 transformers update
+- Added a "num_image_tokens" kwarg to the `ColQwen2Processor` to allow for different image resolutions
 
 ### Fixed
 
colpali_engine/models/paligemma/colpali/modeling_colpali.py (2 changes: 2 additions & 0 deletions)
@@ -34,6 +34,8 @@ def __init__(self, config: PaliGemmaConfig):
     def forward(self, *args, **kwargs) -> torch.Tensor:
         # Delete output_hidden_states from kwargs
         kwargs.pop("output_hidden_states", None)
+        if "pixel_values" in kwargs:
+            kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
 
         outputs = self.model(*args, output_hidden_states=True, **kwargs)  # (batch_size, sequence_length, hidden_size)
         last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
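For context on why the cast matters: when the model weights are loaded in half precision, the processor still emits float32 `pixel_values`, and after the transformers 4.46 update that mismatch surfaces as a dtype error inside the wrapped PaliGemma model. Below is a minimal sketch of the scenario this patch addresses; it assumes the public `colpali_engine` API (`ColPali`, `ColPaliProcessor`, `process_images`) and the `vidore/colpali-v1.2` checkpoint, none of which are defined by this commit itself.

```python
# Hedged repro sketch, not part of the commit. Assumes the public
# colpali_engine API and the vidore/colpali-v1.2 checkpoint.
import torch
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor

model_name = "vidore/colpali-v1.2"  # assumed checkpoint
model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16)
processor = ColPaliProcessor.from_pretrained(model_name)

# The processor returns float32 pixel values regardless of the model dtype.
batch = processor.process_images([Image.new("RGB", (448, 448), "white")])
print(batch["pixel_values"].dtype)  # torch.float32

# With the guard added in this commit, forward() casts pixel_values to
# self.dtype before calling the wrapped PaliGemma model, so bfloat16
# weights and float32 inputs no longer collide.
with torch.no_grad():
    embeddings = model(**batch)
print(embeddings.dtype)  # torch.bfloat16
```

Casting inside `forward()` rather than in the processor keeps the preprocessed batch dtype-agnostic, so the same batch works whether the model is loaded in float32, float16, or bfloat16.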
