
Commit 8691867

hiyouga and ArthurZucker committed

Fix Gemma2 4d attention mask (#31674)

Update modeling_gemma2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

1 parent 7edc993 · commit 8691867

1 file changed (+6, -4 lines)

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 6 additions & 4 deletions
@@ -629,11 +629,13 @@ def forward(
         if (
             self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
         ):  # efficient SDPA and no padding
-            attention_mask = attention_mask * torch.tril(
-                torch.ones_like(attention_mask), diagonal=-self.sliding_window
+            min_dtype = torch.finfo(hidden_states.dtype).min
+            sliding_window_mask = torch.tril(
+                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
             )
-            if attention_mask.shape[1] <= 1:  # when decoding
-                attention_mask = attention_mask[:, -self.sliding_window :]
+            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
+            if attention_mask.shape[-1] <= 1:  # when decoding
+                attention_mask = attention_mask[:, :, :, -self.sliding_window :]

         residual = hidden_states
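For context, a minimal sketch (not the library code) of the pattern the added lines use: building a boolean sliding-window band with torch.tril and writing the dtype minimum into a 4D additive attention mask with torch.where. The shapes and the sliding_window value below are made-up example values.

import torch

# Made-up example values; the real shapes come from the model config and inputs.
batch, q_len, k_len, sliding_window = 1, 6, 6, 4
dtype = torch.float16

# 4D additive mask of shape (batch, 1, q_len, k_len):
# 0.0 where attention is allowed, the dtype minimum where it is blocked.
attention_mask = torch.zeros(batch, 1, q_len, k_len, dtype=dtype)
min_dtype = torch.finfo(dtype).min

# Boolean band selecting keys more than `sliding_window` positions behind the query.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)

# Write the dtype minimum into those positions, leaving the rest of the mask intact.
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

print(attention_mask[0, 0])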

Comments (0)