⚡ Fix GRPO PEFT (#2725)
qgallouedec authored Feb 12, 2025
1 parent 7347c29 commit 8122166
Showing 3 changed files with 73 additions and 9 deletions.
35 changes: 35 additions & 0 deletions tests/test_grpo_trainer.py
@@ -498,3 +498,38 @@ def test_training_with_sync_ref_model(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
@require_peft
def test_training_vllm_and_peft(self):
"""Test that training works with vLLM for generation."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
use_vllm=True,
report_to="none",
)
            trainer = GRPOTrainer(
                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(target_modules="all-linear"),  # minimal LoRA config so the PEFT path is exercised
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the trainable params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
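
Because the trainer wraps the model with a LoRA adapter, only the adapter weights are trainable while the base weights stay frozen, which is why the parameter snapshot above filters on requires_grad. A minimal sketch of that split (assumes peft and transformers are installed; illustration only, not part of the commit):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/small-Qwen2ForCausalLM-2.5")
peft_model = get_peft_model(model, LoraConfig(target_modules="all-linear"))

# Only the lora_A/lora_B adapter matrices require grad; the base weights are frozen.
trainable = [n for n, p in peft_model.named_parameters() if p.requires_grad]
frozen = [n for n, p in peft_model.named_parameters() if not p.requires_grad]
print(f"{len(trainable)} trainable adapter params, {len(frozen)} frozen base params")
```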
31 changes: 23 additions & 8 deletions trl/models/utils.py
@@ -20,6 +20,7 @@

from accelerate.utils import is_deepspeed_available
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.utils.deprecation import deprecate_kwarg

from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead

@@ -37,8 +38,6 @@
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn.parallel.distributed import DistributedDataParallel

-from .modeling_base import PreTrainedModelWrapper


# TODO: Add Abstract Base Class if more formats are added
@dataclass
@@ -176,18 +175,34 @@ def add_hooks(model: "DeepSpeedEngine") -> None:


 @contextmanager
+@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True)
 def unwrap_model_for_generation(
     model: Union["DistributedDataParallel", "DeepSpeedEngine"],
     accelerator: "Accelerator",
-    is_peft_model: bool = False,
     gather_deepspeed3_params: bool = True,
-) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
-    """Context manager to unwrap a model for generation.
-    For ZeRO-3 models, we gather the weights once to speed up generation.
+):
+    """
+    Context manager to unwrap distributed or accelerated models for generation tasks.
+
+    Args:
+        model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
+            Model to be unwrapped.
+        accelerator (`~accelerate.Accelerator`):
+            Accelerator instance managing the model.
+        gather_deepspeed3_params (`bool`, *optional*, defaults to `True`):
+            Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which
+            can be more memory-efficient but may lead to slower generation times.
+
+    Yields:
+        Unwrapped model.
+
+    Example:
+    ```python
+    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
+        generated_outputs = unwrapped_model.generate(input_ids)
+    ```
     """
     unwrapped_model = accelerator.unwrap_model(model)
-    if is_peft_model:
-        unwrapped_model.pretrained_model.disable_adapter()
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
         if not gather_deepspeed3_params:
             yield accelerator.unwrap_model(model)
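The rest of the function is truncated in this view. For context, gathering ZeRO Stage 3 parameter shards before generation is typically done with deepspeed.zero.GatheredParameters; a minimal sketch of the idea (an illustrative assumption, not the verbatim TRL implementation):

```python
import deepspeed

def generate_with_gathered_params(model, input_ids):
    # Inside this context every ZeRO-3 partitioned parameter is temporarily
    # materialized in full, so generate() sees complete weight tensors
    # instead of per-rank shards.
    with deepspeed.zero.GatheredParameters(list(model.parameters())):
        return model.generate(input_ids)
```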
16 changes: 15 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -51,7 +51,7 @@


 if is_peft_available():
-    from peft import PeftConfig, get_peft_model
+    from peft import PeftConfig, PeftModel, get_peft_model

if is_vllm_available():
from vllm import LLM, SamplingParams
@@ -492,6 +492,20 @@ def _move_model_to_vllm(self):
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
state_dict = unwrapped_model._orig_mod.state_dict()
elif isinstance(unwrapped_model, PeftModel):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
state_dict = {
k.removeprefix("base_model.model.").replace(".base_layer", ""): v
for k, v in state_dict.items()
if self.model.prefix not in k
}
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
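For intuition about the two renaming passes above: after merge_adapter(), a PEFT model's state dict still carries the wrapper prefix (base_model.model.), merged-layer suffixes (.base_layer), the adapter weights themselves, and modules_to_save bookkeeping that the plain base model served by vLLM does not expect. A standalone sketch with hypothetical key names (illustration only, not part of the commit) of what the two comprehensions produce:

```python
# Hypothetical keys from a merged LoRA state dict (illustration only).
merged_keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "base_model.model.lm_head.modules_to_save.default.weight",
    "base_model.model.lm_head.original_module.weight",
]

prefix = "lora_"  # stands in for self.model.prefix (the adapter-weight marker)

# Pass 1: drop adapter weights, strip the PEFT wrapper prefix and merged-layer suffix.
keys = [
    k.removeprefix("base_model.model.").replace(".base_layer", "")
    for k in merged_keys
    if prefix not in k
]
# Pass 2: collapse modules_to_save copies and drop the frozen originals.
keys = [k.replace("modules_to_save.default.", "") for k in keys if "original_module" not in k]

print(keys)  # ['model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
```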
