attention_mask_for_sdpa support
poedator committed Dec 10, 2023
1 parent d9bb95c commit 2fb7de9
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/transformers/modeling_attn_mask_utils.py
@@ -352,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     is_tracing = torch.jit.is_tracing()
 
     if attention_mask is not None:
-        if torch.all(attention_mask == 1):
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+            return attention_mask
+
+        elif torch.all(attention_mask == 1):
             if is_tracing:
                 pass
             elif query_length == 1:
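For context, here is a minimal standalone sketch (not part of this commit) of how a caller could now hand a custom 4D mask to the SDPA path. The toy sizes, the causal pattern, and the name additive_bias are illustrative assumptions; only the expected shape (batch_size, 1, query_length, key_value_length), the 1-means-attend convention, and the inversion to a large-negative additive bias come from the diff above.

# Standalone sketch -- assumed toy example, not code from the repository.
import torch

batch_size, query_length, key_value_length = 2, 4, 6  # assumed toy sizes
dtype = torch.float32

# Hypothetical custom 4D mask with the shape the new check expects:
# (batch_size, 1, query_length, key_value_length), where 1 = may attend.
past_length = key_value_length - query_length
causal = torch.tril(
    torch.ones(query_length, key_value_length, dtype=dtype), diagonal=past_length
)
attention_mask = causal[None, None, :, :].expand(batch_size, 1, query_length, key_value_length)

expected_shape = (batch_size, 1, query_length, key_value_length)
assert tuple(attention_mask.shape) == expected_shape  # mirrors the shape check added above

# Same inversion as in the diff: allowed positions -> 0.0, blocked -> dtype minimum.
inverted_mask = 1.0 - attention_mask.to(dtype)
additive_bias = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

# The resulting additive bias can be passed to torch's SDPA kernel directly.
q = torch.randn(batch_size, 1, query_length, 8)  # (batch, heads, q_len, head_dim)
k = torch.randn(batch_size, 1, key_value_length, 8)
v = torch.randn(batch_size, 1, key_value_length, 8)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=additive_bias)
print(out.shape)  # torch.Size([2, 1, 4, 8])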
