refine cross-attention logic
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi committed Oct 19, 2024
1 parent ca62f24 commit 84e83dd
Showing 1 changed file with 14 additions and 5 deletions.
optimum/habana/transformers/models/mllama/modeling_mllama.py
```diff
@@ -777,10 +777,15 @@ def forward(
             if cache_position is not None:
                 cross_attention_mask = cross_attention_mask[:, :, cache_position]
                 full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-            elif token_idx is not None and past_key_values is not None:
-                cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
-                full_text_row_masked_out_mask = torch.index_select(full_text_row_masked_out_mask, -2, token_idx - 1)
-
+            elif past_key_values is not None:
+                if token_idx is not None:
+                    cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1)
+                    full_text_row_masked_out_mask = torch.index_select(
+                        full_text_row_masked_out_mask, -2, token_idx - 1
+                    )
+                else:
+                    cross_attention_mask = cross_attention_mask[:, :, -1:]
+                    full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:]
         outputs = self.language_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
```
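Note on the restructured branch in `forward`: it separates the static-shape decode path (Gaudi-style `token_idx` bucketing, where tensors keep a fixed length and the current position is addressed by index) from the default dynamic path, which always reads the last position. A minimal standalone sketch of what each branch selects, assuming an illustrative 4-D mask layout of `[batch, 1, text_len, image_len]` (names and shapes are for illustration only, not taken from the commit):

```python
import torch

batch, text_len, img_len = 2, 8, 4
cross_attention_mask = torch.ones(batch, 1, text_len, img_len)

# Static-shape path (token_idx set): pick the current position's row with
# index_select, avoiding a data-dependent slice bound.
token_idx = torch.tensor([3])
static = torch.index_select(cross_attention_mask, -2, token_idx - 1)

# Dynamic path (no token_idx): with a growing sequence, the current
# position is always the last row.
dynamic = cross_attention_mask[:, :, -1:]

print(static.shape, dynamic.shape)  # torch.Size([2, 1, 1, 4]) for both
```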
```diff
@@ -895,7 +900,7 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_
         - add token_idx handling
         """
         cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
-        model_kwargs = super()._update_model_kwargs_for_generation(
+        model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation(
             outputs=outputs,
             model_kwargs=model_kwargs,
             is_encoder_decoder=is_encoder_decoder,
```
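The `super()` change is needed because optimum-habana patches this method onto the `transformers` class after the class is defined (stated here as an assumption about how these overrides are installed). Zero-argument `super()` relies on the `__class__` cell that Python creates only for functions defined inside a class body, so it raises `RuntimeError` when called from a patched-in method. A minimal reproduction of the pattern:

```python
class Base:
    def ping(self):
        return "base"

class Child(Base):
    pass

def ping(self):
    # Zero-argument super() would raise RuntimeError here: this function has
    # no __class__ cell because it was defined outside a class body.
    return super(Child, self).ping() + " + patched"

Child.ping = ping  # patch the method onto the class after definition
print(Child().ping())  # base + patched
```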
```diff
@@ -909,6 +914,10 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_
                 mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
                 cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
                 model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
+            else:
+                model_kwargs["cross_attention_mask"] = torch.cat(
+                    [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+                )
         return model_kwargs
 
 
```
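In the update step, both paths keep `cross_attention_mask` aligned with the growing text sequence: the `token_idx` path writes the previous position's row into a pre-allocated mask in place, while the new `else` branch appends one row per generated token. A standalone sketch with illustrative shapes (`[batch, text_len, image_len]` here; the real mask carries more dimensions):

```python
import torch

batch, text_len, img_len = 1, 5, 4
cross_attention_mask_prev = torch.ones(batch, text_len, img_len)

# Static-shape path (token_idx set): copy the row of the previous position
# into the slot of the new position, in place.
token_idx = torch.tensor([3])
mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)

# Dynamic path (no token_idx): grow the mask by duplicating the last row so
# the newly generated token gets the same image visibility.
grown = torch.cat(
    [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
)
print(grown.shape)  # torch.Size([1, 6, 4])
```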


