Commit bf9a21e

make fix-copies
1 parent 2592fb0 commit bf9a21e

2 files changed: 31 additions & 20 deletions

src/transformers/models/dbrx/modeling_dbrx.py

Lines changed: 27 additions & 14 deletions
@@ -1247,23 +1247,38 @@ def _update_causal_mask(
                 else past_seen_tokens + sequence_length + 1
             )
 
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            if attention_mask.dim() == 2:
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+            causal_mask = attention_mask
+        else:
+            if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+                target_length = self.config.max_position_embeddings
+            else:  # dynamic cache
+                target_length = (
+                    attention_mask.shape[-1]
+                    if isinstance(attention_mask, torch.Tensor)
+                    else past_seen_tokens + sequence_length + 1
+                )
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                 padding_mask = padding_mask == 0
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
             elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
+                # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
+                # as the causal mask (i.e. [..., seq_len, full_len])
                 if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                     logger.warning_once(
                         "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
@@ -1272,11 +1287,9 @@ def _update_causal_mask(
                     offset = cache_position[0]
                 else:
                     offset = 0
-                mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
+                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+                causal_mask = mask_slice
 
         if (
             self.config._attn_implementation == "sdpa"
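
For context on the new first branch above: it accepts a custom 4D attention mask only if it is already in "inverted" additive form, i.e. 0.0 at positions that may be attended and a large negative value (the dtype minimum) at blocked positions, which is what the attention_mask.max() != 0 check enforces. A minimal standalone sketch of building such a mask (shapes and names here are illustrative, not taken from the diff):

import torch

# Illustrative sizes: batch 1, broadcastable head dim, 4 query tokens, 6 key/value tokens.
allowed = torch.tril(torch.ones(4, 6, dtype=torch.bool), diagonal=2)  # True = may attend

dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Inverted additive mask: 0 where attention is allowed, min_dtype where it is blocked.
inverted_4d_mask = torch.zeros(1, 1, 4, 6, dtype=dtype)
inverted_4d_mask = inverted_4d_mask.masked_fill(~allowed[None, None, :, :], min_dtype)

assert inverted_4d_mask.max() == 0  # the invariant the new code checks before using the mask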

src/transformers/models/olmo/modeling_olmo.py

Lines changed: 4 additions & 6 deletions
@@ -1089,8 +1089,8 @@ def _update_causal_mask(
                     padding_mask, min_dtype
                 )
             elif attention_mask.dim() == 4:
-                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-                # cache. In that case, the 4D attention mask attends to the newest tokens only.
+                # we can pass both the full 4D mask (i.e. [..., full_len, full_len]) and a 4D mask with the same shape
+                # as the causal mask (i.e. [..., seq_len, full_len])
                 if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                     logger.warning_once(
                         "Passing a 4d mask shorter than the input length is deprecated and will be removed in "
@@ -1099,11 +1099,9 @@ def _update_causal_mask(
                     offset = cache_position[0]
                 else:
                     offset = 0
-                mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
+                mask_slice = mask_slice[..., offset : offset + sequence_length, :]
+                causal_mask = mask_slice
 
         if (
             self.config._attn_implementation == "sdpa"
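
As a standalone illustration of the changed slicing above (sizes are made up for the example): when the passed 4D mask covers the full sequence but only sequence_length new tokens are processed at offset = cache_position[0], the updated lines convert the mask to additive form and keep just the rows for the current query positions, instead of writing a slice into a preallocated causal mask:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

full_len, sequence_length, offset = 8, 2, 6  # 6 cached tokens, 2 new query tokens
# Full-length 4D mask in 0/1 form: 1.0 where attention is allowed, 0.0 where blocked.
attention_mask = torch.tril(torch.ones(1, 1, full_len, full_len, dtype=dtype))

# Same two steps as the updated code: invert to additive form, then slice the query rows.
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
mask_slice = mask_slice[..., offset : offset + sequence_length, :]

print(mask_slice.shape)  # torch.Size([1, 1, 2, 8]) -> one row per new query token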
