diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index 7ffc81687d8..267b8f40a5b 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -593,7 +593,7 @@ def _decoder_forward( decoder_input_ids, decoder_attention_mask, decoder_position_ids, - encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states, **kwargs, ) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index faabeae17fa..0326fee63ee 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -627,7 +627,7 @@ def _decoder_forward( decoder_input_ids, decoder_attention_mask, decoder_position_ids, - encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states, **kwargs, )