@@ -1247,23 +1247,38 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )

-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+                target_length = self.config.max_position_embeddings
+            else:  # dynamic cache
+                target_length = (
+                    attention_mask.shape[-1]
+                    if isinstance(attention_mask, torch.Tensor)
+                    else past_seen_tokens + sequence_length + 1
+                )
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
             elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
+                # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
+                # as the causal mask (i.e. [..., seq_len, full_len])
                 if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                     logger.warning_once(
                         "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
@@ -1272,11 +1287,9 @@ def _update_causal_mask(
                     offset = cache_position[0]
                 else:
                     offset = 0
-                mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
+                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+                causal_mask = mask_slice

         if (
             self.config._attn_implementation == "sdpa"
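
For context on the first hunk: the new early branch takes a caller-supplied 4D mask as-is, so it must already be in inverted additive form, with 0.0 at positions that may be attended and a large negative value everywhere else, which is what the `attention_mask.max() != 0` check enforces. A minimal sketch of such a mask, with toy shapes and names that are not taken from the diff:

```python
# Illustration only: a hand-built 4D additive mask in the "inverted" form the new
# branch accepts. Shapes and variable names here are toy assumptions, not from the diff.
import torch

batch_size, q_len, kv_len = 1, 4, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min  # mirrors the min_dtype fill value used in the diff

# Boolean pattern of allowed positions: here an ordinary lower-triangular causal pattern.
allowed = torch.tril(torch.ones(q_len, kv_len)).bool()

# Inverted additive form: 0.0 where attention is allowed, min_dtype where it is masked.
custom_4d_mask = torch.full((q_len, kv_len), min_dtype, dtype=dtype)
custom_4d_mask = custom_4d_mask.masked_fill(allowed, 0.0)
custom_4d_mask = custom_4d_mask[None, None, :, :].expand(batch_size, 1, q_len, kv_len)

assert custom_4d_mask.max() == 0  # passes the sanity check added at the top of the new branch
```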
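
And for the second hunk: instead of writing the re-inverted 4D mask into a slice of the pre-built causal mask, the new code slices the re-inverted mask down to the current query rows and uses it directly. A toy illustration of those two lines (the re-inversion line is kept from the existing code), with `offset` standing in for whatever value the surrounding branch selected:

```python
# Illustration only, with toy shapes not taken from the diff.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, full_len, offset = 2, 6, 4

# Hypothetical caller-supplied 4D mask using the 1.0 = attend / 0.0 = masked convention.
attention_mask = torch.ones(1, 1, full_len, full_len).tril()

# Re-invert to additive form: 0.0 where attention is allowed, min_dtype where masked.
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype

# Keep only the query rows for the current step and use the result directly,
# rather than assigning it into a slice of a pre-built causal mask.
mask_slice = mask_slice[..., offset : offset + sequence_length, :]
causal_mask = mask_slice  # shape (1, 1, 2, 6)
```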