diff --git a/CHANGELOG.md b/CHANGELOG.md index 361117a4f82..f7d97420e0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983)) +- Added warning to `CLIPScore` if long captions are detected and truncated ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001)) + + ### Changed - diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index ae3c59c33d1..72fd179c5c2 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -17,6 +17,7 @@ from torch import Tensor from typing_extensions import Literal +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 @@ -65,6 +66,17 @@ def _clip_score_update( img_features = model.get_image_features(processed_input["pixel_values"].to(device)) img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) + max_position_embeddings = model.config.text_config.max_position_embeddings + if processed_input["attention_mask"].shape[-1] > max_position_embeddings: + rank_zero_warn( + f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length. "
+ "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports " + "longer sequences", + UserWarning, + ) + processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings] + processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings] + txt_features = model.get_text_features( processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device) ) diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 4a8658f1c55..3413187905e 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -127,3 +127,15 @@ def test_plot_method(self, inputs, model_name_or_path): fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + + @skip_on_connection_issues() + def test_warning_on_long_caption(self, inputs, model_name_or_path): + """Test that warning is given on long captions but metric still works.""" + metric = CLIPScore(model_name_or_path=model_name_or_path) + preds, target = inputs + target[0] = [target[0][0], "A 28-year-old chef who recently moved to San Francisco was found dead. " * 100] + with pytest.warns( + UserWarning, + match="Encountered caption longer than max_position_embeddings=77. Will truncate captions to this length.*", + ): + metric.update(preds[0], target[0])