Commit

simpler 4d mask shape check
poedator committed Nov 20, 2023
1 parent c0e4dc9 commit 53a7e77
Showing 1 changed file with 11 additions and 7 deletions.
src/transformers/modeling_attn_mask_utils.py (18 changes: 11 additions & 7 deletions)
@@ -218,13 +218,17 @@ def _prepare_4d_causal_attention_mask(
             attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
         )
     elif attention_mask is not None and len(attention_mask.shape) == 4:
-        if (
-            attention_mask.shape[0] != input_shape[0]
-            or attention_mask.shape[1] != 1
-            or attention_mask.shape[2] != input_shape[1]
-            or attention_mask.shape[3] != key_value_length
-        ):
-            raise ValueError(f"Incorrect 4D attention_mask shape: {attention_mask.shape}")
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            # (
+            #     attention_mask.shape[0] != input_shape[0]
+            #     or attention_mask.shape[1] != 1
+            #     or attention_mask.shape[2] != input_shape[1]
+            #     or attention_mask.shape[3] != key_value_length
+            # ):
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
         else:
             # if the 4D mask has correct shape - invert it and fill with negative infinity
             inverted_mask = 1.0 - attention_mask
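
For context, here is a standalone sketch (not part of the commit; the tensor names and sizes are made up for illustration) of what the simplified check compares: a user-supplied 4D attention_mask must have shape (batch_size, 1, query_length, key_value_length), and the single tuple comparison replaces the four element-wise comparisons spelled out in the old code.

# Illustrative sketch of the simplified 4D mask shape check.
# batch_size, query_length, key_value_length are hypothetical example values,
# standing in for input_shape[0], input_shape[1] and key_value_length in the diff.
import torch

batch_size, query_length, key_value_length = 2, 5, 5
input_shape = (batch_size, query_length)

expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)

good_mask = torch.zeros(expected_shape)                                     # correct shape
bad_mask = torch.zeros(batch_size, 1, query_length, key_value_length + 3)   # wrong last dim

for mask in (good_mask, bad_mask):
    if tuple(mask.shape) != expected_shape:
        # same comparison the new code performs in one step
        print(f"Incorrect 4D attention_mask shape: {tuple(mask.shape)}; expected: {expected_shape}.")
    else:
        print(f"accepted: {tuple(mask.shape)}")

Running this accepts the first mask and flags the second, which is the single comparison the new expected_shape tuple performs instead of checking each dimension separately.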