buffer inputs with grad accum
qgallouedec committed Feb 20, 2025
1 parent a178fd9 commit 946d1f6
Showing 2 changed files with 62 additions and 32 deletions.
8 changes: 4 additions & 4 deletions trl/trainer/grpo_config.py
@@ -89,8 +89,8 @@ class GRPOConfig(TrainingArguments):
     [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
             KL coefficient.
-        num_updates (`int`, *optional*, defaults to `1`):
-            Number of updates per batch.
+        num_iterations (`int`, *optional*, defaults to `1`):
+            Number of iterations per batch (denoted as μ in the algorithm).
         epsilon (`float`, *optional*, defaults to `0.2`):
             Epsilon value for clipping.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
@@ -224,9 +224,9 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
-    num_updates: int = field(
+    num_iterations: int = field(
         default=1,
-        metadata={"help": "Number of updates per batch."},
+        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
     )
     epsilon: float = field(
         default=0.2,
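For orientation, a minimal usage sketch of the renamed option (not part of the commit; `output_dir` and the value `2` are arbitrary):

```python
from trl import GRPOConfig

# num_iterations (formerly num_updates) is the number of optimization steps
# performed on each generated batch, i.e. μ in the GRPO paper.
training_args = GRPOConfig(output_dir="grpo-out", num_iterations=2)
```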
86 changes: 58 additions & 28 deletions trl/trainer/grpo_trainer.py
@@ -90,14 +90,19 @@ class RepeatRandomSampler(Sampler):
     ```
     ```txt
-    mini_repeat_count
-     -  -
-    [4, 4, 3, 3, 0, 0,  |
-     4, 4, 3, 3, 0, 0,  |
-     4, 4, 3, 3, 0, 0,  |  repeat_count
-     4, 4, 3, 3, 0, 0]  |
-     ----  ----  ----
-        batch_size
+    mini_repeat_count = 3
+     -   -   -
+    [0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+     4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+     8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,      |
+                                                           repeat_count = 2
+     0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,      |
+     4,  4,  4,  5,  5,  5,  6,  6,  6,  7,  7,  7,      |
+     8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11, ...] |
+     ---------   ---------   ---------   ---------
+     ---------   ---------   ---------   ---------
+     ---------   ---------   ---------   ---------
+                    batch_size = 12
     ```
     """

@@ -348,10 +353,16 @@ def data_collator(features):  # No data collation is needed in GRPO
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.use_vllm = args.use_vllm
-        self.num_updates = args.num_updates
-        self.epsilon = args.epsilon
         self.beta = args.beta
-        self._update_remaning = 0
+
+        # Multi-step
+        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
+        self.epsilon = args.epsilon
+        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
+        self._step = 0
+        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
+        # `_get_train_sampler` and `_prepare_inputs`.
+        self._buffered_inputs = [None] * args.gradient_accumulation_steps

         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
         # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@@ -509,28 +520,48 @@ def _set_signature_columns_if_needed(self):
self._signature_columns = ["prompt"]

     def _get_train_sampler(self) -> Sampler:
-        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
-        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
-        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
-        # preventing discrepancies in group formation.
+        # Returns a sampler that
+        # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
+        #    distributed to different GPUs, allowing rewards to be computed and normalized correctly within each
+        #    prompt group. Using the same seed across processes ensures consistent prompt assignment, preventing
+        #    discrepancies in group formation.
+        # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
+        #    _prepare_inputs to see how the generations are stored and reused.

+        #                                   |     GPU 0     |     GPU 1     |     GPU 2     |
+        #
+        #  global_step   step                <───────>   num_generations=3
+        #                                    <───────────>   per_device_train_batch_size=4
+        #       0         0                   0   0   0   1   1   1   2   2   2   3   3   3  │
+        #       0         1                   4   4   4   5   5   5   6   6   6   7   7   7  │  gradient_accumulation=3
+        #       0         2                   8   8   8   9   9   9  10  10  10  11  11  11  │
+        #
+        #       1         3                   0   0   0   1   1   1   2   2   2   3   3   3  │  num_iterations=2:
+        #       1         4                   4   4   4   5   5   5   6   6   6   7   7   7  │  reuse the batch once
+        #       1         5                   8   8   8   9   9   9  10  10  10  11  11  11  │
+        #
+        #       2         6                  12  12  12  13  13  13  14  14  14  15  15  15
+        #       2         7                  16  16  16  17  17  17  18  18  18  19  19  19
+        #       2         8                  20  20  20  21  21  21  22  22  22  23  23  23
+        #       ...
+        effective_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.accelerator.num_processes
+            * self.args.gradient_accumulation_steps
+        )
         return RepeatRandomSampler(
             data_source=self.train_dataset,
             mini_repeat_count=self.num_generations,
-            batch_size=self.args.per_device_train_batch_size * self.num_generations // self.accelerator.num_processes,
-            repeat_count=self.num_updates,
+            batch_size=effective_batch_size // self.num_generations,
+            repeat_count=self.num_iterations,
             seed=self.args.seed,
         )

     def _get_eval_sampler(self, eval_dataset) -> Sampler:
-        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
-        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
-        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
-        # preventing discrepancies in group formation.
+        # See _get_train_sampler for an explanation of the sampler.
         return RepeatRandomSampler(
             data_source=eval_dataset,
             mini_repeat_count=self.num_generations,
             batch_size=self.args.per_device_eval_batch_size * self.accelerator.num_processes,
-            repeat_count=self.num_updates,
             seed=self.args.seed,
         )

@@ -578,13 +609,12 @@ def _move_model_to_vllm(self):
unwrapped_model.unmerge_adapter()

     def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
-        if self._update_remaning == 0:
+        if self.state.global_step % self.num_iterations == 0:
             inputs = self._generate_and_score_completions(inputs)
-            self._buffered_inputs = inputs
-            self._update_remaning = self.num_updates
+            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
         else:
-            inputs = self._buffered_inputs
-            self._update_remaning -= 1
+            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+        self._step += 1
         return inputs

def _generate_and_score_completions(
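A toy trace (not from the commit) of the buffering schedule that `_prepare_inputs` implements, assuming `num_iterations=2`, `gradient_accumulation_steps=3`, and that `global_step` advances once per full accumulation cycle:

```python
num_iterations = 2
gradient_accumulation_steps = 3
buffered = [None] * gradient_accumulation_steps

step = 0  # counts forward/backward passes, including accumulation micro-steps
for global_step in range(4):  # four optimizer steps
    for micro in range(gradient_accumulation_steps):
        slot = step % gradient_accumulation_steps
        if global_step % num_iterations == 0:
            # Generation steps: produce completions and buffer them per micro-batch slot.
            buffered[slot] = f"generations from step {global_step}.{micro}"
            print(global_step, micro, "generate")
        else:
            # Reuse steps: read the micro-batch buffered during the previous cycle.
            print(global_step, micro, "reuse:", buffered[slot])
        step += 1
```

With these values, global steps 0 and 2 generate three micro-batches each, and global steps 1 and 3 replay them slot by slot, matching the "num_iterations=2: reuse the batch once" annotation in the sampler comment.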
