1 file changed: +4 -1 lines changed
src/transformers/models/gemma2

@@ -602,6 +602,7 @@ def forward(
 class Gemma2DecoderLayer(nn.Module):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size

         self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
@@ -625,7 +626,9 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
+        if (
+            self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
+        ):  # efficient SDPA and no padding
             attention_mask = attention_mask * torch.tril(
                 torch.ones_like(attention_mask), diagonal=-self.sliding_window
             )
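
Why the guard matters: the first hunk stores the config on the layer so that forward can read self.config._attn_implementation, and the second hunk then skips the sliding-window mask rewrite whenever that implementation is "flash_attention_2". The sketch below only illustrates the shape assumption behind the skipped code. It assumes the usual transformers mask conventions (SDPA/eager paths receive a 4D float mask, the flash-attention path receives the raw 2D padding mask and applies the sliding window inside the kernel), which this hunk does not show, and the tensor names are purely illustrative.

import torch

sliding_window = 2

# 4D mask as the SDPA/eager paths would see it: (batch, 1, q_len, kv_len).
# torch.tril on a batched tensor acts on the last two dims, so the band
# restriction lands on the (q_len, kv_len) plane as intended.
mask_4d = torch.zeros(1, 1, 5, 5)
restricted = mask_4d * torch.tril(torch.ones_like(mask_4d), diagonal=-sliding_window)
print(restricted.shape)  # torch.Size([1, 1, 5, 5])

# 2D padding mask as flash_attention_2 would see it: (batch, seq_len), 1 = real token.
# Here torch.tril treats (batch, seq_len) itself as the matrix, so the same
# expression wipes out the padding information instead of windowing anything.
mask_2d = torch.ones(2, 5)
corrupted = mask_2d * torch.tril(torch.ones_like(mask_2d), diagonal=-sliding_window)
print(corrupted)  # all zeros for this example

This is only a sketch of the mismatch the new self.config._attn_implementation != "flash_attention_2" check avoids, not a reproduction of the full masking logic.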