Skip to content

Commit

Permalink
📋 Add eval loss logging during prediction in GRPO (#2694)
Browse files Browse the repository at this point in the history
* add eval loss logging during predition

* make sure the train and eval logs aren't mixed

* test grpo in eval

* fix tests

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
kashif and qgallouedec authored Jan 30, 2025
1 parent ab30a01 commit fecaa99
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,30 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_with_eval(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2, # reduce the batch size to reduce memory usage
per_device_eval_batch_size=2, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
eval_strategy="steps",
eval_steps=2,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)

trainer.train()

@require_peft
def test_training_peft(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
Expand Down Expand Up @@ -347,6 +371,7 @@ def test_training_vllm(self):
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU
)
trainer = GRPOTrainer(
model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
Expand Down
13 changes: 13 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,21 @@ def get_per_token_logps(model, input_ids, num_logits_to_keep):

return loss

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
with torch.no_grad():
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
loss = loss.mean().detach()
return loss, None, None

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics

# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if next(iter(logs.keys())).startswith("eval_"):
metrics = {f"eval_{key}": val for key, val in metrics.items()}

logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
super().log(logs, start_time)
Expand Down

0 comments on commit fecaa99

Please sign in to comment.