Skip to content

Commit 076521a

Browse files
author
Mark Lee
authored
Forward input keys to decoder. (apple#944)
1 parent 30284c8 commit 076521a

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

axlearn/common/causal_lm.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -475,22 +475,12 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
475475
hidden_states: a float Tensor of shape [batch_size, seq_len, hidden_dim]
476476
"""
477477
self._constrain_input_batch(input_batch)
478-
input_ids: Tensor = input_batch["input_ids"]
479-
token_type_ids: Optional[Tensor] = input_batch.get("token_type_ids")
480-
input_segment_ids: Optional[Tensor] = input_batch.get("input_segment_ids")
481-
input_positions: Optional[Tensor] = input_batch.get("input_positions")
478+
# TODO(markblee): Simplify by using consistent naming between `input_positions` and
479+
# `positions`, `input_segment_ids` and `segment_ids`.
482480
# Decoder hidden states: [batch_size, target_len, hidden_dim].
483-
decoder_output = self.decoder(
484-
# TODO(markblee): Simplify by using consistent naming between `input_positions` and
485-
# `positions`, `input_segment_ids` and `segment_ids`.
486-
input_batch=dict(
487-
input_ids=input_ids,
488-
token_type_ids=token_type_ids,
489-
input_segment_ids=input_segment_ids,
490-
positions=input_positions,
491-
),
492-
)
493-
return decoder_output
481+
decoder_batch = {**input_batch}
482+
decoder_batch["positions"] = input_batch.get("input_positions")
483+
return self.decoder(input_batch=decoder_batch)
494484

495485
def _metrics(
496486
self, input_batch: Nested[Tensor], *, predict_outputs: Nested[Tensor]

0 commit comments

Comments
 (0)