From 44bd590a2926af8f3921a997ed6543de8e2d018a Mon Sep 17 00:00:00 2001 From: Jungwoo Park Date: Tue, 6 Jun 2023 00:47:29 +0900 Subject: [PATCH] Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976) * fix wrong broadcast axis of attention mask in visual encoder * fix slow tests --------- Co-authored-by: younesbelkada --- src/transformers/models/pix2struct/modeling_pix2struct.py | 6 +++--- tests/models/pix2struct/test_modeling_pix2struct.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 1a900e264c023f..539c95eda0c7b3 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -210,7 +210,7 @@ def to_projection_shape(states): attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) if attention_mask.dim() == 2: - position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device) + position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) else: # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + attention_mask.to(position_bias.device) @@ -1695,7 +1695,7 @@ def forward( >>> generated_ids = model.generate(**inputs, max_new_tokens=50) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> print(generated_text) - A picture of a stop sign with a red stop sign on it. + A picture of a stop sign with a red stop sign ``` Training: @@ -1719,7 +1719,7 @@ def forward( >>> outputs = model(**inputs, labels=labels) >>> loss = outputs.loss >>> print(f"{loss.item():.5f}") - 5.95566 + 5.94282 ```""" use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 4dbd7e649f16e3..4d028d111d0f14 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -757,12 +757,12 @@ def test_batched_inference_image_captioning_conditioned(self): self.assertEqual( processor.decode(predictions[0], skip_special_tokens=True), - "A picture of a stop sign with a red stop sign on it.", + "A picture of a stop sign with a red stop sign", ) self.assertEqual( processor.decode(predictions[1], skip_special_tokens=True), - "An photography of the Temple Bar and the Temple Bar.", + "An photography of the Temple Bar and other places in the city.", ) def test_vqa_model(self):