
Commit 7edc993

winglian authored and ArthurZucker committed
don't zero out the attention_mask when using sliding window with flash attention (#31670)
* don't zero out the attention_mask when using sliding window with flash attention
* chore: lint
1 parent e3cb841 commit 7edc993

File tree

1 file changed: +4, -1 lines changed


src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 4 additions & 1 deletion
@@ -602,6 +602,7 @@ def forward(
 class Gemma2DecoderLayer(nn.Module):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size

         self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
@@ -625,7 +626,9 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
+        if (
+            self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
+        ):  # efficient SDPA and no padding
             attention_mask = attention_mask * torch.tril(
                 torch.ones_like(attention_mask), diagonal=-self.sliding_window
             )
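A minimal sketch of why the new guard matters (this is not the library's code path; the shapes and values are illustrative assumptions). With flash_attention_2, the attention_mask reaching the decoder layer is typically a 2D padding mask of shape (batch, seq_len) rather than the 4D mask used by SDPA/eager, so multiplying it by the sliding-window tril wipes it out entirely, which is the behavior the commit title describes:

import torch

batch, seq_len, sliding_window = 2, 8, 4096  # assumed toy values

# Hypothetical 2D padding mask in the (batch, seq_len) layout:
# 1 = real token, 0 = padding.
padding_mask = torch.ones(batch, seq_len)
padding_mask[1, 6:] = 0  # second sequence is padded at the end

# Applying the sliding-window tril to this 2D tensor: with diagonal=-4096
# the lower-triangular band is empty, so the product is all zeros and the
# padding information is destroyed.
damaged = padding_mask * torch.tril(torch.ones_like(padding_mask), diagonal=-sliding_window)
print(damaged)  # tensor of zeros: the attention_mask has been zeroed out

With the added check, the 2D flash-attention mask is left untouched and only the SDPA/eager 4D mask gets the sliding-window restriction applied in the decoder layer.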
