@@ -475,22 +475,12 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
475
475
hidden_states: a float Tensor of shape [batch_size, seq_len, hidden_dim]
476
476
"""
477
477
self ._constrain_input_batch (input_batch )
478
- input_ids : Tensor = input_batch ["input_ids" ]
479
- token_type_ids : Optional [Tensor ] = input_batch .get ("token_type_ids" )
480
- input_segment_ids : Optional [Tensor ] = input_batch .get ("input_segment_ids" )
481
- input_positions : Optional [Tensor ] = input_batch .get ("input_positions" )
478
+ # TODO(markblee): Simplify by using consistent naming between `input_positions` and
479
+ # `positions`, `input_segment_ids` and `segment_ids`.
482
480
# Decoder hidden states: [batch_size, target_len, hidden_dim].
483
- decoder_output = self .decoder (
484
- # TODO(markblee): Simplify by using consistent naming between `input_positions` and
485
- # `positions`, `input_segment_ids` and `segment_ids`.
486
- input_batch = dict (
487
- input_ids = input_ids ,
488
- token_type_ids = token_type_ids ,
489
- input_segment_ids = input_segment_ids ,
490
- positions = input_positions ,
491
- ),
492
- )
493
- return decoder_output
481
+ decoder_batch = {** input_batch }
482
+ decoder_batch ["positions" ] = input_batch .get ("input_positions" )
483
+ return self .decoder (input_batch = decoder_batch )
494
484
495
485
def _metrics (
496
486
self , input_batch : Nested [Tensor ], * , predict_outputs : Nested [Tensor ]
0 commit comments