diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 6c43aac766..c5b80620f8 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -89,8 +89,8 @@ class GRPOConfig(TrainingArguments):
             [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
-        num_updates (`int`, *optional*, defaults to `1`):
-            Number of updates per batch.
+        num_iterations (`int`, *optional*, defaults to `1`):
+            Number of iterations per batch (denoted as μ in the algorithm).
         epsilon (`float`, *optional*, defaults to `0.2`):
             Epsilon value for clipping.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
@@ -224,9 +224,9 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
-    num_updates: int = field(
+    num_iterations: int = field(
         default=1,
-        metadata={"help": "Number of updates per batch."},
+        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
     )
     epsilon: float = field(
         default=0.2,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index f20d1a5415..c5f3db7c4b 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -90,14 +90,19 @@ class RepeatRandomSampler(Sampler):
     ```

     ```txt
-    mini_repeat_count
-     -  -  -
-    [4, 4, 3, 3, 0, 0, |
-     4, 4, 3, 3, 0, 0, |
-     4, 4, 3, 3, 0, 0, |  repeat_count
-     4, 4, 3, 3, 0, 0] |
-     ----  ----  ----
-       batch_size
+    mini_repeat_count = 3
+          -   -   -
+         [0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,      |
+                                                                 repeat_count = 2
+          0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+          4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+          8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, ...] |
+         ---------   ---------   ---------   ---------
+          ---------   ---------   ---------   ---------
+           ---------   ---------   ---------   ---------
+                             batch_size = 12
     ```
     """
@@ -348,10 +353,16 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
-        self.num_updates = args.num_updates
-        self.epsilon = args.epsilon
         self.beta = args.beta
-        self._update_remaning = 0
+
+        # Multi-step
+        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
+        self.epsilon = args.epsilon
+        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
+        self._step = 0
+        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
+        # `_get_train_sampler` and `_prepare_inputs`.
+        self._buffered_inputs = [None] * args.gradient_accumulation_steps

         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
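To make the new docstring diagram concrete, here is a minimal standalone sketch of the index order it describes. `repeat_pattern` is a hypothetical helper, not the sampler's actual code: it uses the identity permutation for readability, whereas the real `RepeatRandomSampler` draws a seeded random permutation.

```python
from collections.abc import Iterator


def repeat_pattern(
    num_samples: int, mini_repeat_count: int, batch_size: int, repeat_count: int
) -> Iterator[int]:
    # Identity permutation for readability; the real sampler shuffles with a seeded generator.
    indexes = list(range(num_samples))
    # Chunk into batches of `batch_size` unique indices, dropping an incomplete tail batch.
    chunks = [indexes[i : i + batch_size] for i in range(0, len(indexes), batch_size)]
    for chunk in (c for c in chunks if len(c) == batch_size):
        for _ in range(repeat_count):  # reuse the same batch `repeat_count` times
            for index in chunk:
                for _ in range(mini_repeat_count):  # one repeat per generation (G)
                    yield index


# Matches the diagram: 0, 0, 0, 1, 1, 1, ..., 11, 11, 11, then the whole block once more.
print(list(repeat_pattern(num_samples=12, mini_repeat_count=3, batch_size=12, repeat_count=2)))
```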
@@ -509,28 +520,48 @@ def _set_signature_columns_if_needed(self):
             self._signature_columns = ["prompt"]

     def _get_train_sampler(self) -> Sampler:
-        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
-        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
-        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
-        # preventing discrepancies in group formation.
+        # Returns a sampler that
+        # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
+        #    distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
+        #    group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
+        #    in group formation.
+        # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
+        #    _prepare_inputs to see how the generations are stored and reused.
+
+        #                            |     GPU 0     |     GPU 1     |     GPU 2     |
+        #
+        #  global_step   step        <──────────>  num_generations=3
+        #                            <──────────────>  per_device_train_batch_size=4
+        #  0             0            0   0   0   1   1   1   2   2   2   3   3   3  │
+        #  0             1            4   4   4   5   5   5   6   6   6   7   7   7  │  gradient_accumulation=3
+        #  0             2            8   8   8   9   9   9  10  10  10  11  11  11  │
+        #
+        #  1             3            0   0   0   1   1   1   2   2   2   3   3   3  │  num_iterations=2:
+        #  1             4            4   4   4   5   5   5   6   6   6   7   7   7  │  reuse the batch once
+        #  1             5            8   8   8   9   9   9  10  10  10  11  11  11  │
+        #
+        #  2             6           12  12  12  13  13  13  14  14  14  15  15  15
+        #  2             7           16  16  16  17  17  17  18  18  18  19  19  19
+        #  2             8           20  20  20  21  21  21  22  22  22  23  23  23
+        #  ...
+        effective_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
+        )
         return RepeatRandomSampler(
             data_source=self.train_dataset,
             mini_repeat_count=self.num_generations,
-            batch_size=self.args.per_device_train_batch_size * self.num_generations // self.accelerator.num_processes,
-            repeat_count=self.num_updates,
+            batch_size=effective_batch_size // self.num_generations,
+            repeat_count=self.num_iterations,
             seed=self.args.seed,
         )

     def _get_eval_sampler(self, eval_dataset) -> Sampler:
-        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
-        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
-        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
-        # preventing discrepancies in group formation.
+        # See _get_train_sampler for an explanation of the sampler.
         return RepeatRandomSampler(
             data_source=eval_dataset,
             mini_repeat_count=self.num_generations,
-            batch_size=self.args.per_device_eval_batch_size * self.accelerator.num_processes,
-            repeat_count=self.num_updates,
             seed=self.args.seed,
         )
@@ -578,13 +609,12 @@ def _move_model_to_vllm(self):
                 unwrapped_model.unmerge_adapter()

     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
-        if self._update_remaning == 0:
+        if self.state.global_step % self.num_iterations == 0:
             inputs = self._generate_and_score_completions(inputs)
-            self._buffered_inputs = inputs
-            self._update_remaning = self.num_updates
+            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
         else:
-            inputs = self._buffered_inputs
-            self._update_remaning -= 1
+            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+        self._step += 1
         return inputs

     def _generate_and_score_completions(
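To see how the new `_prepare_inputs` bookkeeping plays out, here is a toy simulation of the buffering schedule. It is a sketch, not trainer code: it assumes `num_iterations=2` and `gradient_accumulation_steps=3` as in the sampler comment above, and it models `state.global_step` as incrementing once per full accumulation cycle, as in the Hugging Face `Trainer`.

```python
# Toy simulation (not trainer code) of the buffering schedule in _prepare_inputs.
num_iterations = 2            # mu in the GRPO paper: optimizer steps per generation round
gradient_accumulation_steps = 3

buffered = [None] * gradient_accumulation_steps  # mirrors self._buffered_inputs
step = 0                                         # mirrors self._step (forward/backward passes)

for global_step in range(4):                      # optimizer steps (trainer's state.global_step)
    for _ in range(gradient_accumulation_steps):  # one _prepare_inputs call per micro-batch
        slot = step % gradient_accumulation_steps
        if global_step % num_iterations == 0:
            # First iteration over this batch: generate and buffer the outputs.
            buffered[slot] = f"generations@step{step}"
            print(f"global_step={global_step} step={step}: generate -> buffer[{slot}]")
        else:
            # Later iterations: reuse the buffered generations for the same micro-batch.
            print(f"global_step={global_step} step={step}: reuse buffer[{slot}] = {buffered[slot]}")
        step += 1
```

Each buffer slot holds the generations for one micro-batch of the accumulation cycle; with this schedule, slot `k` is written during the generating optimizer step and read back `num_iterations - 1` times afterwards, which is exactly why the sampler repeats each batch `num_iterations` times.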