Skip to content

Commit

Permalink
Pix2Struct: fix wrong broadcast axis of attention mask in visual encoder (#23976)
Browse files Browse the repository at this point in the history

* fix wrong broadcast axis of attention mask in visual encoder

* fix slow tests

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
  • Loading branch information
affjljoo3581 and younesbelkada authored Jun 5, 2023
1 parent 7824fa4 commit 44bd590
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def to_projection_shape(states):
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)

if attention_mask.dim() == 2:
position_bias = position_bias + attention_mask[:, None, :, None].to(position_bias.device)
position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
else:
# (batch_size, n_heads, seq_length, key_length)
position_bias = position_bias + attention_mask.to(position_bias.device)
Expand Down Expand Up @@ -1695,7 +1695,7 @@ def forward(
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
A picture of a stop sign with a red stop sign on it.
A picture of a stop sign with a red stop sign
```
Training:
Expand All @@ -1719,7 +1719,7 @@ def forward(
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> print(f"{loss.item():.5f}")
5.95566
5.94282
```"""
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down
4 changes: 2 additions & 2 deletions tests/models/pix2struct/test_modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,12 +757,12 @@ def test_batched_inference_image_captioning_conditioned(self):

self.assertEqual(
processor.decode(predictions[0], skip_special_tokens=True),
"A picture of a stop sign with a red stop sign on it.",
"A picture of a stop sign with a red stop sign",
)

self.assertEqual(
processor.decode(predictions[1], skip_special_tokens=True),
"An photography of the Temple Bar and the Temple Bar.",
"An photography of the Temple Bar and other places in the city.",
)

def test_vqa_model(self):
Expand Down

0 comments on commit 44bd590

Please sign in to comment.