src/transformers/models/moshi/modeling_moshi.py (35 changes: 33 additions & 2 deletions)
@@ -2099,6 +2099,31 @@ def forward(
             depth_attentions=None if decoder_outputs is None else decoder_outputs.attentions,
         )
 
+    def _prepare_attention_mask_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generation_config: GenerationConfig,
+        kwargs: Dict[str, Any],
+    ) -> torch.LongTensor:
+        pad_token_id = generation_config.pad_token_id
+        eos_token_id = generation_config.eos_token_id
+
+        default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
+        if pad_token_id is None:
+            return default_attention_mask
+
+        is_pad_token_in_inputs = (pad_token_id is not None) and torch.isin(input_ids, pad_token_id).any()
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~torch.isin(
+            eos_token_id, pad_token_id
+        ).any()
+        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+        attention_mask_from_padding = input_ids.ne(pad_token_id).long()
+
+        attention_mask = (
+            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+        )
+        return attention_mask
+
     def _prepare_inputs_embeds_for_generation(
         self,
         input_ids: Optional[torch.LongTensor] = None,
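
To make the new helper's behavior concrete, here is a standalone sketch (not part of the diff) that runs the same inference rule on a made-up, right-padded batch; the token ids and the pad/eos values are illustrative only, and are passed as tensors the way the generation config prepares its special tokens.

# Standalone sketch of the inference rule above (illustrative values, not from the PR).
import torch

input_ids = torch.tensor([[17, 23, 42, 0, 0],      # right-padded with pad_token_id = 0
                          [17, 23, 42, 61, 2]])    # ends with eos_token_id = 2
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor(2)

default_attention_mask = torch.ones(input_ids.shape, dtype=torch.long)

# The mask can only be inferred when pad tokens are present and pad != eos,
# otherwise genuine eos tokens would be masked out as padding.
is_pad_token_in_inputs = torch.isin(input_ids, pad_token_id).any()
is_pad_token_not_equal_to_eos_token_id = ~torch.isin(eos_token_id, pad_token_id).any()
can_infer = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id

attention_mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = attention_mask_from_padding * can_infer + default_attention_mask * ~can_infer
print(attention_mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
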
@@ -2315,6 +2340,12 @@ def generate(
         kwargs_depth_decoder = depth_decoder_generation_config
 
         attention_mask = kwargs.pop("attention_mask", None)
+        if attention_mask is None:
+            attention_mask = self._prepare_attention_mask_for_generation(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                kwargs=kwargs,
+            )
         (
             inputs_embeds,
             input_ids,
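
With the hunk above, a padded batch can be handed to generate() without an explicit attention_mask, because one is now inferred from the pad tokens before inputs_embeds are built. The call sketch below is hypothetical: the checkpoint name, the made-up token ids, the assumed pad id, and the assumption that text-only input_ids plus max_new_tokens is enough for a minimal call are not taken from the PR.

# Hypothetical usage sketch (checkpoint and ids are assumptions, not from the PR).
import torch
from transformers import MoshiForConditionalGeneration

model = MoshiForConditionalGeneration.from_pretrained("kyutai/moshiko-pytorch-bf16")

# A right-padded text prompt batch; token ids are illustrative only.
input_ids = torch.tensor([[318, 1025, 9, 3, 3],
                          [318, 1025, 77, 12, 256]])
model.generation_config.pad_token_id = 3  # assumed pad id for this sketch

# No attention_mask passed: with this change it is inferred from the pad tokens above.
output = model.generate(input_ids=input_ids, max_new_tokens=8)
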
@@ -2497,11 +2528,11 @@ def prepare_inputs_for_generation(
             batch_size, sequence_length = input_ids.shape
             device = input_ids.device
 
-            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask = self.decoder.model._prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
                 sequence_length=sequence_length,
                 target_length=past_key_values.get_max_cache_shape(),
-                dtype=self.lm_head.weight.dtype,
+                dtype=self.decoder.lm_head.weight.dtype,
                 device=device,
                 cache_position=cache_position,
                 batch_size=batch_size,
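
The two replaced lines reroute the static-cache mask helper and the lm_head dtype lookup through self.decoder. Below is a toy sketch of the assumed module nesting that motivates the fix; the classes are stand-ins, not the real Moshi modules.

# Toy nesting only; it mirrors the assumption that the conditional-generation wrapper
# holds the text decoder as `self.decoder`, which in turn owns `model` and `lm_head`.
import torch.nn as nn

class ToyDecoder(nn.Module):           # plays the role of the causal-LM decoder
    def __init__(self):
        super().__init__()
        self.model = nn.Identity()     # plays the role of the inner transformer ("model")
        self.lm_head = nn.Linear(8, 8)

class ToyWrapper(nn.Module):           # plays the role of the conditional-generation wrapper
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder()

wrapper = ToyWrapper()
print(hasattr(wrapper, "lm_head"))            # False: a bare self.lm_head lookup would fail here
print(wrapper.decoder.lm_head.weight.dtype)   # torch.float32: the dtype the mask helper needs
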