@@ -371,9 +371,7 @@ def forward(
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask
-            if cache_position is not None:
-                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
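
A minimal shape sketch (illustrative sizes, not from this diff) of why slicing only the key dimension is now enough: the incoming 4D mask already has one row per current query position, so no row indexing by cache_position is needed.

import torch

bsz, n_heads, q_len, kv_len = 2, 8, 1, 37            # decoding one token against a 37-token cache
attn_weights = torch.randn(bsz, n_heads, q_len, kv_len)
attention_mask = torch.zeros(bsz, 1, q_len, 64)      # mask built up to the padded/static key length
causal_mask = attention_mask[:, :, :, :kv_len]       # slice columns down to the real key length
attn_weights = attn_weights + causal_mask            # broadcasts over the head dimension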
@@ -658,8 +656,9 @@ def forward(
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+        # if attention_mask is not None and cache_position is not None:
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
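
A hedged sketch of the kind of contiguity workaround this comment refers to; the helper name and structure here are illustrative, not the file's actual code.

import torch
import torch.nn.functional as F

def sdpa_with_mask(q, k, v, mask):
    # Workaround for https://github.com/pytorch/pytorch/issues/112577: with a custom
    # attn_mask, non-contiguous q/k/v can misbehave on some backends (torch==2.1.2),
    # so force contiguity before dispatching.
    if mask is not None:
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)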
@@ -792,7 +791,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
@@ -815,12 +814,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

-        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full(
-                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
-            )
-            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
-
        for layer in self.model.layers:
            device = layer.input_layernorm.weight.device
            if hasattr(self.config, "_pre_quantization_dtype"):
@@ -934,12 +927,6 @@ def __init__(self, config: LlamaConfig):
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

-        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
-        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
-        causal_mask = torch.full(
-            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
-        )
-        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
        # Initialize weights and apply final processing
        self.post_init()
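
For context on why the buffer registration above is dropped, a back-of-the-envelope sketch of its memory cost; the 8192 value is illustrative, not taken from this PR.

# torch.bool stores one byte per element, so the persistent mask costs max_position_embeddings**2 bytes.
max_position_embeddings = 8192                      # illustrative value
buffer_bytes = max_position_embeddings ** 2         # 67_108_864 bytes
print(f"{buffer_bytes / 2**20:.0f} MiB")            # -> 64 MiB, growing quadratically with context length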
@@ -1000,7 +987,7 @@ def forward(
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, past_seen_tokens)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        # embed positions
        hidden_states = inputs_embeds
@@ -1068,25 +1055,27 @@ def forward(
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-    def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

-        batch_size, seq_length = input_tensor.shape[:2]
-        dtype = input_tensor.dtype
-        device = input_tensor.device
-
-        # support going beyond cached `max_position_embedding`
-        if seq_length > self.causal_mask.shape[-1]:
-            causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
-            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
-
-        # We use the current dtype to avoid any overflows
+        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
-        causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
-        causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
+        sequence_length = input_tensor.shape[1]
+        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+            target_length = self.config.max_position_embeddings
+        else:  # dynamic cache
+            target_length = (
+                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+            )
+
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
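
A self-contained sketch of what the added mask construction produces, using illustrative values (4 new tokens appended after a 2-token cache, 6 key positions total); none of these numbers come from the diff.

import torch

dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 4, 6                  # 4 new queries, 6 total key positions
cache_position = torch.arange(2, 6)                    # the new tokens sit at absolute positions 2..5

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)  # block the strictly upper triangle
# Keep min_dtype only where the key index exceeds the query's absolute position:
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
print(causal_mask)  # row i is 0 for keys 0..cache_position[i] and min_dtype beyond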
@@ -1096,8 +1085,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, past_seen_tokens):
            elif attention_mask.dim() == 4:
                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                # cache. In that case, the 4D attention mask attends to the newest tokens only.
-                if attention_mask.shape[-2] < past_seen_tokens + input_tensor.shape[1]:
-                    offset = past_seen_tokens
+                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                    offset = cache_position[0]
                else:
                    offset = 0
                mask_shape = attention_mask.shape
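
A hedged illustration (made-up tensors, not code from the file) of the offset logic above: when a user passes a 4D mask that covers only the newest query rows, cache_position[0] gives the absolute row at which those rows start.

import torch

cache_position = torch.arange(10, 14)                  # 4 new tokens after 10 cached ones
sequence_length = 4
user_mask = torch.zeros(1, 1, 4, 14)                   # 4D mask with rows only for the newest queries
is_short = user_mask.shape[-2] < cache_position[0] + sequence_length   # 4 < 14 -> True
offset = cache_position[0] if is_short else 0          # the user's rows start at absolute position 10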