
Commit a6752a7

Fix serving_output for TF composite models (encoder-decoder like models) (#22743)
* fix
* style
* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 410b61a commit a6752a7

2 files changed: +20, -12 lines


src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

Lines changed: 10 additions & 6 deletions
@@ -633,14 +633,18 @@ def call(
         )
 
     def serving_output(self, output):
-        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
-        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
-        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
-        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
-        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
+        dec_hs = (
+            tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
+        )
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
+        enc_hs = (
+            tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
+        )
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
         cross_attns = (
             tf.convert_to_tensor(output.cross_attentions)
-            if self.config.output_attentions and output.cross_attentions is not None
+            if self.config.decoder.output_attentions and output.cross_attentions is not None
             else None
         )

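The reasoning behind the change, with an illustrative sketch that is not part of the commit: a composite EncoderDecoderConfig is assembled from two independent sub-configs, and flags such as output_hidden_states, output_attentions and use_cache live on config.encoder / config.decoder, while the top-level config keeps the PretrainedConfig defaults. Reading the top-level flags in serving_output can therefore disagree with what the wrapped encoder and decoder actually return. The configs below are assumptions chosen only to show the mismatch.

# Illustrative sketch only (not from the commit): where the relevant flags
# live on a composite encoder-decoder configuration.
from transformers import BertConfig, EncoderDecoderConfig

encoder_config = BertConfig(output_hidden_states=True)
decoder_config = BertConfig(is_decoder=True, add_cross_attention=True, use_cache=True)

config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# The top-level composite config keeps the PretrainedConfig defaults and does
# not mirror the sub-config flags, so it is not a reliable source of truth.
print(config.output_hidden_states)          # False (base default)
print(config.encoder.output_hidden_states)  # True
print(config.decoder.use_cache)             # True
# use_cache may not even be set on the top-level composite config:
print(getattr(config, "use_cache", "<not set>"))

With the change above, serving_output consults the same sub-configs the encoder and decoder were built from, so the exported signature matches the tensors that are actually present in the output.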
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py

Lines changed: 10 additions & 6 deletions
@@ -662,14 +662,18 @@ def call(
         )
 
     def serving_output(self, output):
-        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
-        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
-        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
-        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
-        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
+        pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
+        dec_hs = (
+            tf.convert_to_tensor(output.decoder_hidden_states) if self.config.decoder.output_hidden_states else None
+        )
+        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.decoder.output_attentions else None
+        enc_hs = (
+            tf.convert_to_tensor(output.encoder_hidden_states) if self.config.encoder.output_hidden_states else None
+        )
+        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.encoder.output_attentions else None
         cross_attns = (
             tf.convert_to_tensor(output.cross_attentions)
-            if self.config.output_attentions and output.cross_attentions is not None
+            if self.config.decoder.output_attentions and output.cross_attentions is not None
             else None
         )

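The same reasoning applies to the vision variant. As a rough sketch of how the fixed path gets exercised (an assumption about the export flow, not something added by the commit): serving_output post-processes the model outputs for the TF serving signature, so exporting a randomly initialised composite model as a SavedModel goes through it.

# Rough sketch (illustrative, not from the commit): exporting a TF composite
# model as a SavedModel goes through the serving signature, whose outputs are
# post-processed by serving_output, exercising the code path changed above.
import tempfile

from transformers import (
    GPT2Config,
    TFVisionEncoderDecoderModel,
    ViTConfig,
    VisionEncoderDecoderConfig,
)

# Randomly initialised model built from default configs; no checkpoints needed.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(ViTConfig(), GPT2Config())
model = TFVisionEncoderDecoderModel(config=config)

with tempfile.TemporaryDirectory() as tmp:
    # saved_model=True additionally exports a TensorFlow SavedModel.
    model.save_pretrained(tmp, saved_model=True)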