Commit 43ee585

Fix MusicGen SDPA (#31208)
* fix sdpa musicgen
* make style
* remove copied from statement from Musicgen SDPA
1 parent 833fc17 commit 43ee585

File tree

1 file changed (+17 -1 lines)


src/transformers/models/musicgen/modeling_musicgen.py

Lines changed: 17 additions & 1 deletion
@@ -545,7 +545,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         )
 
 
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->Musicgen
 class MusicgenSdpaAttention(MusicgenAttention):
     def forward(
         self,
@@ -572,6 +571,23 @@ def forward(
                 output_attentions=output_attentions,
             )
 
+        if (
+            attention_mask is not None
+            and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
+        ):
+            logger.warning_once(
+                '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
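
For context, the sketch below (not part of the commit) illustrates the failure mode the new branch guards against. With `guidance_scale>1` or `get_unconditional_inputs`, MusicGen can build an attention mask whose every entry is the dtype minimum, meaning no position may be attended to. The tensor shapes and values here are illustrative assumptions, not the model's real ones.

# Minimal sketch of the detection check, assuming illustrative shapes.
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16)  # (batch, num_heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# A fully "empty" additive mask: every entry is the dtype minimum, which is
# how transformers encodes masked-out positions.
mask = torch.full((1, 1, 8, 8), torch.finfo(q.dtype).min)

# The commit's detection: a per-example mean at or below finfo.min is only
# possible when every position is masked out.
if (mask.mean(dim=[1, 2, 3]) <= torch.finfo(mask.dtype).min).any():
    print("fully masked -> fall back to the manual (eager) attention")
else:
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

With such a mask, softmax has no valid position to normalize over, so SDPA backends can return NaN outputs (see issue 31189), whereas the manual implementation's numerically stable softmax over the finite finfo.min scores stays well defined. That is why the added branch reroutes to super().forward() instead of calling SDPA.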
