Skip to content

Commit

Permalink
Forward input keys to decoder. (apple#944)
Browse files: browse the repository at this point in the history
  • Loading branch information
markblee authored Jan 23, 2025
1 parent 30284c8 commit 076521a
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions axlearn/common/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,22 +475,12 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
hidden_states: a float Tensor of shape [batch_size, seq_len, hidden_dim]
"""
self._constrain_input_batch(input_batch)
input_ids: Tensor = input_batch["input_ids"]
token_type_ids: Optional[Tensor] = input_batch.get("token_type_ids")
input_segment_ids: Optional[Tensor] = input_batch.get("input_segment_ids")
input_positions: Optional[Tensor] = input_batch.get("input_positions")
# TODO(markblee): Simplify by using consistent naming between `input_positions` and
# `positions`, `input_segment_ids` and `segment_ids`.
# Decoder hidden states: [batch_size, target_len, hidden_dim].
decoder_output = self.decoder(
# TODO(markblee): Simplify by using consistent naming between `input_positions` and
# `positions`, `input_segment_ids` and `segment_ids`.
input_batch=dict(
input_ids=input_ids,
token_type_ids=token_type_ids,
input_segment_ids=input_segment_ids,
positions=input_positions,
),
)
return decoder_output
decoder_batch = {**input_batch}
decoder_batch["positions"] = input_batch.get("input_positions")
return self.decoder(input_batch=decoder_batch)

def _metrics(
self, input_batch: Nested[Tensor], *, predict_outputs: Nested[Tensor]
Expand Down

0 comments on commit 076521a

Please sign in to comment.