From 076521a44d358d21d0ac5fb72f6fb86168f73be3 Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Thu, 23 Jan 2025 15:11:00 -0800 Subject: [PATCH] Forward input keys to decoder. (#944) --- axlearn/common/causal_lm.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py index 75f4e73c2..63cbb99b4 100644 --- a/axlearn/common/causal_lm.py +++ b/axlearn/common/causal_lm.py @@ -475,22 +475,12 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]: hidden_states: a float Tensor of shape [batch_size, seq_len, hidden_dim] """ self._constrain_input_batch(input_batch) - input_ids: Tensor = input_batch["input_ids"] - token_type_ids: Optional[Tensor] = input_batch.get("token_type_ids") - input_segment_ids: Optional[Tensor] = input_batch.get("input_segment_ids") - input_positions: Optional[Tensor] = input_batch.get("input_positions") + # TODO(markblee): Simplify by using consistent naming between `input_positions` and + # `positions`, `input_segment_ids` and `segment_ids`. # Decoder hidden states: [batch_size, target_len, hidden_dim]. - decoder_output = self.decoder( - # TODO(markblee): Simplify by using consistent naming between `input_positions` and - # `positions`, `input_segment_ids` and `segment_ids`. - input_batch=dict( - input_ids=input_ids, - token_type_ids=token_type_ids, - input_segment_ids=input_segment_ids, - positions=input_positions, - ), - ) - return decoder_output + decoder_batch = {**input_batch} + decoder_batch["positions"] = input_batch.get("input_positions") + return self.decoder(input_batch=decoder_batch) def _metrics( self, input_batch: Nested[Tensor], *, predict_outputs: Nested[Tensor]