diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py
index 28a61d590..7f09fccac 100644
--- a/optimum/habana/transformers/models/mllama/modeling_mllama.py
+++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py
@@ -777,10 +777,15 @@ def forward(
             if cache_position is not None:
                 cross_attention_mask = cross_attention_mask[:, :, cache_position]
                 full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-            elif token_idx is not None and past_key_values is not None:
-                cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
-                full_text_row_masked_out_mask = torch.index_select(full_text_row_masked_out_mask, -2, token_idx - 1)
-
+            elif past_key_values is not None:
+                if token_idx is not None:
+                    cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
+                    full_text_row_masked_out_mask = torch.index_select(
+                        full_text_row_masked_out_mask, -2, token_idx - 1
+                    )
+                else:
+                    cross_attention_mask = cross_attention_mask[:, :, -1:]
+                    full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:]
         outputs = self.language_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
@@ -895,7 +900,7 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_
         - add token_idx handling
         """
         cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
-        model_kwargs = super()._update_model_kwargs_for_generation(
+        model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation(
             outputs=outputs,
             model_kwargs=model_kwargs,
             is_encoder_decoder=is_encoder_decoder,
@@ -909,6 +914,10 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_
                 mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
                 cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
                 model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
+            else:
+                model_kwargs["cross_attention_mask"] = torch.cat(
+                    [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+                )

         return model_kwargs

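For context, a minimal, self-contained sketch of the two mask-update paths this patch switches between in _update_model_kwargs_for_generation: in-place index_copy_ when token_idx is set (Gaudi static-shape decoding) versus the upstream torch.cat growth when it is not. The helper name update_cross_attention_mask and the tensor shapes below are illustrative assumptions, not code from the repository.

# Sketch only; shapes and helper name are assumptions, not part of the patch.
from typing import Optional

import torch


def update_cross_attention_mask(prev: torch.Tensor, token_idx: Optional[torch.Tensor] = None) -> torch.Tensor:
    if token_idx is not None:
        # Static-shape decoding: the mask is pre-allocated to the max length, so the row
        # used for the previous position is copied in place into the new token's slot.
        mask = prev[:, token_idx - 2 : token_idx - 1, ...]
        prev.index_copy_(1, token_idx - 1, mask)
        return prev
    # Dynamic-shape decoding (upstream behaviour): append a copy of the last row,
    # growing the sequence dimension by one for the newly generated token.
    return torch.cat([prev, prev[:, -1:, ...]], dim=1)


# Example shapes (batch, seq_len, max_num_images, max_num_tiles); values are illustrative.
prev = torch.zeros(1, 8, 1, 4)
prev[:, :3] = 1  # three prompt positions attend to the image

static = update_cross_attention_mask(prev.clone(), token_idx=torch.tensor([4]))
dynamic = update_cross_attention_mask(prev.clone())
print(static.shape)   # torch.Size([1, 8, 1, 4]) -- size unchanged, slot 3 filled in place
print(dynamic.shape)  # torch.Size([1, 9, 1, 4]) -- one row appended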