
Commit

Merge branch 'main' into multi-step-grpo
qgallouedec authored Feb 20, 2025
2 parents 946d1f6 + a92e00e commit 035e4be
Showing 4 changed files with 186 additions and 15 deletions.
83 changes: 82 additions & 1 deletion tests/test_grpo_trainer.py
@@ -29,7 +29,7 @@


if is_peft_available():
from peft import LoraConfig
from peft import LoraConfig, PeftModel


class RepeatRandomSamplerTester(unittest.TestCase):
@@ -236,6 +236,57 @@ def test_training_peft(self):
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

@require_peft
def test_training_peft_with_gradient_checkpointing(self):
"""Test that training works with PEFT and gradient checkpointing enabled."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
torch_dtype=torch.float32, # Use float32 for testing to avoid precision issues
use_cache=False, # Required for gradient checkpointing
)

lora_config = LoraConfig(
r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
gradient_checkpointing=True, # Enable gradient checkpointing
report_to="none",
)
trainer = GRPOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
peft_config=lora_config,
)

# Verify that the model has been wrapped as a PeftModel (gradient checkpointing is applied to its base model)
self.assertIsInstance(trainer.model, PeftModel)

# Store initial parameters to check which ones change
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that only LoRA parameters have changed, base model parameters remain unchanged
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "lora" in n.lower(): # LoRA parameters should change
self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
else: # Base model parameters should not change
self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")

def test_training_different_reward_model(self):
# Use a reward model different from the model: different chat template, tokenization, etc.
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
@@ -603,6 +654,36 @@ def test_training_with_sync_ref_model(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_beta_zero_no_ref_model_and_no_kl(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
beta=0.0, # set beta to 0 to test the case where the reference model is not used
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
@require_peft
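Not part of the diff, but for orientation: the first new test above mirrors the user-facing setup it covers, LoRA plus gradient checkpointing with `GRPOTrainer`. A minimal sketch of that setup, where the checkpoint, `my_reward_fn`, and `my_prompt_dataset` are placeholders assumed to exist:

```python
# Sketch only: GRPO training with a LoRA adapter and gradient checkpointing enabled.
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

args = GRPOConfig(
    output_dir="grpo-lora-gc",
    gradient_checkpointing=True,   # routed through the new _enable_gradient_checkpointing helper
    per_device_train_batch_size=4,
    num_generations=4,
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",     # placeholder checkpoint
    reward_funcs=my_reward_fn,            # assumed: callable(completions, **kwargs) -> list[float]
    args=args,
    train_dataset=my_prompt_dataset,      # assumed: dataset with a "prompt" column
    peft_config=LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"]),
)
trainer.train()
```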
41 changes: 41 additions & 0 deletions trl/extras/profiling.py
@@ -0,0 +1,41 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time

from transformers import is_wandb_available


if is_wandb_available():
import wandb


def profiling_decorator(func):
"""
Decorator to profile a function and log the time taken to execute it.
"""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time

if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
return result

return wrapper
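For orientation (not part of the commit), a minimal sketch of how the new decorator behaves outside the trainers. `ToyTrainer` is hypothetical; the wrapper only assumes the instance exposes `self.args.report_to` and `self.accelerator.is_main_process`, which is what it inspects before logging to Weights & Biases:

```python
# Sketch only: a hypothetical class instrumented with profiling_decorator.
import time
from types import SimpleNamespace

from trl.extras.profiling import profiling_decorator


class ToyTrainer:
    def __init__(self):
        # Stand-ins for the attributes the wrapper reads from `self`.
        self.args = SimpleNamespace(report_to=[])  # no "wandb" -> duration is measured but not logged
        self.accelerator = SimpleNamespace(is_main_process=True)

    @profiling_decorator
    def slow_step(self):
        time.sleep(0.1)  # simulated work; the wrapper times this call
        return "done"


print(ToyTrainer().slow_step())  # prints "done"
```

With `report_to=["wandb"]` and an active `wandb` run, each call would additionally log a `profiling/Time taken: ToyTrainer.slow_step` entry, which is how the decorated `GRPOTrainer` methods below report their timings.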
10 changes: 7 additions & 3 deletions trl/trainer/grpo_config.py
@@ -88,11 +88,12 @@ class GRPOConfig(TrainingArguments):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
speed.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping
Epsilon value for clipping.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
@@ -222,7 +223,10 @@ class GRPOConfig(TrainingArguments):
)
beta: float = field(
default=0.04,
metadata={"help": "KL coefficient."},
metadata={
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
"training speed."
},
)
num_iterations: int = field(
default=1,
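As a usage note (a sketch, not part of the diff), the documented `beta=0.0` behavior requires nothing beyond setting the field; the trainer then never builds a reference model. The checkpoint, reward function, and dataset below are placeholders:

```python
# Sketch only: GRPO without a KL penalty, so no reference model is loaded.
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir="grpo-no-kl",
    beta=0.0,                        # KL term disabled -> the reference model is never created
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=64,
    report_to="none",
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",                                                  # placeholder checkpoint
    reward_funcs=lambda completions, **kwargs: [-float(len(c)) for c in completions],  # toy reward: prefer short outputs
    args=training_args,
    train_dataset=my_prompt_dataset,                                                   # assumed: dataset with a "prompt" column
)

assert trainer.ref_model is None  # with beta == 0.0, this commit skips the reference model entirely
trainer.train()
```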
67 changes: 56 additions & 11 deletions trl/trainer/grpo_trainer.py
Expand Up @@ -43,6 +43,7 @@
from transformers.utils import is_peft_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
@@ -284,19 +285,30 @@ def __init__(
"This argument can only be used when the `model` argument is a string."
)

self.beta = args.beta

if peft_config is not None:
if not is_peft_available():
raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
model = get_peft_model(model, peft_config)

# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
model = self._enable_gradient_checkpointing(model, args)

# Reference model
if is_deepspeed_zero3_enabled():
if self.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
else:
elif is_peft_model(model):
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
else:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)

# Processing class
if processing_class is None:
@@ -565,7 +577,30 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
seed=self.args.seed,
)

def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False

# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
model.gradient_checkpointing_enable()

gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)

if use_reentrant:
model.enable_input_require_grads()

return model

# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence are later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
@@ -577,6 +612,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

@profiling_decorator
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
@@ -608,6 +644,7 @@ def _move_model_to_vllm(self):
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
if self.state.global_step % self.num_iterations == 0:
inputs = self._generate_and_score_completions(inputs)
@@ -698,7 +735,9 @@ def _generate_and_score_completions(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)

if self.ref_model is not None:
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
@@ -805,6 +844,7 @@ def _generate_and_score_completions(
"advantages": advantages,
}

@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
@@ -819,8 +859,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
if self.beta != 0.0:
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)

# Compute the loss
advantages = inputs["advantages"]
@@ -830,15 +873,17 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
per_token_loss = per_token_loss + self.beta * per_token_kl
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)

mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
if self.beta != 0.0:
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

# Clip ratio
is_clipped = (per_token_loss1 < per_token_loss2).float()
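For readers tracing the `compute_loss` changes, the per-token objective the guarded code computes can be restated compactly (a summary of the surrounding code, using $\pi_\theta$ for the current policy, $\pi_{\theta_\text{old}}$ for the sampling policy, $\pi_\text{ref}$ for the reference policy, $A$ for the advantage, and $\epsilon$ for the clipping range):

```latex
\begin{aligned}
r_t &= \exp\!\big(\log \pi_\theta(o_t) - \log \pi_{\theta_\text{old}}(o_t)\big) \\
\mathcal{L}_t &= -\min\!\big(r_t A,\ \operatorname{clip}(r_t,\, 1-\epsilon,\, 1+\epsilon)\, A\big) + \beta\, \mathbb{D}_t \\
\mathbb{D}_t &= \exp\!\big(\log \pi_\text{ref}(o_t) - \log \pi_\theta(o_t)\big)
               - \big(\log \pi_\text{ref}(o_t) - \log \pi_\theta(o_t)\big) - 1
\end{aligned}
```

When `beta == 0.0`, the $\beta\,\mathbb{D}_t$ term, the reference forward pass that produces $\log \pi_\text{ref}$, and the `kl` metric are all skipped, which is exactly what the new `if self.beta != 0.0` guards and the `ref_per_token_logps = None` branch above implement.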
