File tree Expand file tree Collapse file tree 2 files changed +20
-12
lines changed Expand file tree Collapse file tree 2 files changed +20
-12
lines changed Original file line number Diff line number Diff line change @@ -633,14 +633,18 @@ def call(
633633 )
634634
635635 def serving_output (self , output ):
636- pkv = tf .tuple (output .past_key_values )[1 ] if self .config .use_cache else None
637- dec_hs = tf .convert_to_tensor (output .decoder_hidden_states ) if self .config .output_hidden_states else None
638- dec_attns = tf .convert_to_tensor (output .decoder_attentions ) if self .config .output_attentions else None
639- enc_hs = tf .convert_to_tensor (output .encoder_hidden_states ) if self .config .output_hidden_states else None
640- enc_attns = tf .convert_to_tensor (output .encoder_attentions ) if self .config .output_attentions else None
636+ pkv = tf .tuple (output .past_key_values )[1 ] if self .config .decoder .use_cache else None
637+ dec_hs = (
638+ tf .convert_to_tensor (output .decoder_hidden_states ) if self .config .decoder .output_hidden_states else None
639+ )
640+ dec_attns = tf .convert_to_tensor (output .decoder_attentions ) if self .config .decoder .output_attentions else None
641+ enc_hs = (
642+ tf .convert_to_tensor (output .encoder_hidden_states ) if self .config .encoder .output_hidden_states else None
643+ )
644+ enc_attns = tf .convert_to_tensor (output .encoder_attentions ) if self .config .encoder .output_attentions else None
641645 cross_attns = (
642646 tf .convert_to_tensor (output .cross_attentions )
643- if self .config .output_attentions and output .cross_attentions is not None
647+ if self .config .decoder . output_attentions and output .cross_attentions is not None
644648 else None
645649 )
646650
Original file line number Diff line number Diff line change @@ -662,14 +662,18 @@ def call(
662662 )
663663
664664 def serving_output (self , output ):
665- pkv = tf .tuple (output .past_key_values )[1 ] if self .config .use_cache else None
666- dec_hs = tf .convert_to_tensor (output .decoder_hidden_states ) if self .config .output_hidden_states else None
667- dec_attns = tf .convert_to_tensor (output .decoder_attentions ) if self .config .output_attentions else None
668- enc_hs = tf .convert_to_tensor (output .encoder_hidden_states ) if self .config .output_hidden_states else None
669- enc_attns = tf .convert_to_tensor (output .encoder_attentions ) if self .config .output_attentions else None
665+ pkv = tf .tuple (output .past_key_values )[1 ] if self .config .decoder .use_cache else None
666+ dec_hs = (
667+ tf .convert_to_tensor (output .decoder_hidden_states ) if self .config .decoder .output_hidden_states else None
668+ )
669+ dec_attns = tf .convert_to_tensor (output .decoder_attentions ) if self .config .decoder .output_attentions else None
670+ enc_hs = (
671+ tf .convert_to_tensor (output .encoder_hidden_states ) if self .config .encoder .output_hidden_states else None
672+ )
673+ enc_attns = tf .convert_to_tensor (output .encoder_attentions ) if self .config .encoder .output_attentions else None
670674 cross_attns = (
671675 tf .convert_to_tensor (output .cross_attentions )
672- if self .config .output_attentions and output .cross_attentions is not None
676+ if self .config .decoder . output_attentions and output .cross_attentions is not None
673677 else None
674678 )
675679
You can’t perform that action at this time.
0 commit comments