
Commit e18f233

Fix default attention mask of generate in MoshiForConditionalGeneration (#36171)
1 parent: 27d1707

File tree

1 file changed: +33 −2 lines changed


src/transformers/models/moshi/modeling_moshi.py

Lines changed: 33 additions & 2 deletions
@@ -2099,6 +2099,31 @@ def forward(
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )

+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
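
The helper above follows the usual pad-token heuristic in transformers generation: if a pad token is defined, appears in the inputs, and is distinct from the EOS token, the mask is inferred from padding; otherwise an all-ones mask is returned. A standalone sketch of that rule in plain PyTorch (the pad/eos ids below are toy values, not Moshi's real special tokens):

import torch

pad_id, eos_id = 0, 2
input_ids = torch.tensor([[0, 0, 5, 6, 7],
                          [4, 5, 6, 7, 8]])

default_mask = torch.ones_like(input_ids)                       # fallback: attend everywhere
has_pad = torch.isin(input_ids, torch.tensor(pad_id)).any()     # pad token appears in the batch
pad_is_not_eos = ~torch.isin(torch.tensor(eos_id), torch.tensor(pad_id)).any()
can_infer = has_pad & pad_is_not_eos

mask_from_padding = input_ids.ne(pad_id).long()                 # 0 on pad positions, 1 elsewhere
attention_mask = torch.where(can_infer, mask_from_padding, default_mask)
print(attention_mask)
# tensor([[0, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1]])

Multiplying by the boolean `can_infer_attention_mask` in the diff performs the same selection as the `torch.where` used here.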
@@ -2315,6 +2340,12 @@ def generate(
             kwargs_depth_decoder = depth_decoder_generation_config

         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
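
At the call site, `generate` now builds the mask itself whenever the caller omits `attention_mask`. A hedged usage sketch of the effect, assuming a Moshi checkpoint id of `kyutai/moshiko-pytorch-bf16` and text-only inputs (both are illustrative assumptions, not taken from the commit; real Moshi generation typically also receives audio inputs):

import torch
from transformers import MoshiForConditionalGeneration

model_id = "kyutai/moshiko-pytorch-bf16"  # assumed checkpoint id, not stated in the commit
model = MoshiForConditionalGeneration.from_pretrained(model_id)

# If the checkpoint's generation config defines no pad token, the new helper simply
# returns an all-ones mask, so pick a stand-in id for illustration.
pad_id = model.generation_config.pad_token_id
pad_id = 0 if pad_id is None else pad_id

# Left-padded batch entry, with no explicit attention_mask passed to generate().
input_ids = torch.tensor([[pad_id, pad_id, 10, 11, 12]])

# Before this commit the mask stayed None here; now generate() calls
# _prepare_attention_mask_for_generation and masks out the two pad positions.
output = model.generate(input_ids=input_ids, max_new_tokens=8)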
@@ -2497,11 +2528,11 @@ def prepare_inputs_for_generation(
             batch_size, sequence_length = input_ids.shape
             device = input_ids.device

-            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
                 target_length=past_key_values.get_max_cache_shape(),
-                dtype=self.lm_head.weight.dtype,
+                dtype=self.decoder.lm_head.weight.dtype,
                 device=device,
                 cache_position=cache_position,
                 batch_size=batch_size,
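
The last hunk is an attribute-path fix rather than a logic change: `MoshiForConditionalGeneration` wraps the causal-LM decoder, so the 4D-mask helper and the `lm_head` dtype have to be reached through `self.decoder`. A rough structural sketch (the class names below are illustrative stand-ins, not the real transformers classes):

class CausalLMSketch:
    """Stand-in for the wrapped causal LM that owns the backbone and the LM head."""
    def __init__(self):
        self.model = object()    # backbone exposing _prepare_4d_causal_attention_mask_with_cache_position
        self.lm_head = object()  # projection whose weight dtype is reused when building the 4D mask


class ConditionalGenerationSketch:
    """Stand-in for the wrapper: it holds the causal LM as `decoder` and has no
    `model`/`lm_head` attributes of its own, which is why the old `self.model` /
    `self.lm_head` lookups in prepare_inputs_for_generation could not resolve."""
    def __init__(self):
        self.decoder = CausalLMSketch()


wrapper = ConditionalGenerationSketch()
assert hasattr(wrapper.decoder, "lm_head") and not hasattr(wrapper, "lm_head")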
