diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index da929c0867fcf3..e23bb876c3f037 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -1093,7 +1093,11 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
-        if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
+        if (
+            isinstance(past_key_values, HybridCache)
+            and attention_mask.ndim == 2
+            and not self.config._attn_implementation == "flash_attention_2"
+        ):
             if model_inputs["inputs_embeds"] is not None:
                 batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                 device = model_inputs["inputs_embeds"].device
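
For readers skimming the hunk, here is a minimal standalone sketch of the new guard. The `_Config` dataclass and `needs_4d_mask` helper are illustrative only, not part of the PR; the intent mirrored here is that the 4D static-mask expansion is skipped when `flash_attention_2` is the configured attention implementation, presumably because that code path consumes the 2D padding mask directly.

```python
# Illustrative sketch of the guard added in the diff above (names are hypothetical).
from dataclasses import dataclass

import torch


@dataclass
class _Config:
    _attn_implementation: str  # e.g. "eager", "sdpa", or "flash_attention_2"


def needs_4d_mask(config: _Config, attention_mask: torch.Tensor, cache_is_hybrid: bool) -> bool:
    # Mirrors the new condition: only expand the mask when a HybridCache is used,
    # the mask is 2D, and flash_attention_2 is NOT the selected implementation.
    return (
        cache_is_hybrid
        and attention_mask.ndim == 2
        and not config._attn_implementation == "flash_attention_2"
    )


mask = torch.ones(1, 8, dtype=torch.long)
print(needs_4d_mask(_Config("eager"), mask, cache_is_hybrid=True))              # True
print(needs_4d_mask(_Config("flash_attention_2"), mask, cache_is_hybrid=True))  # False
```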