đŸ§© PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional (#2557)

* PPO/RLOO/OnlineDPO: add ds3_gather_for_generation argument to control weights gathering for generation

* code formatting

* rephrase and document

* more doc

* style [ci skip]

* Trigger CI

---------

Co-authored-by: Quentin GallouĂ©dec <[email protected]>
Co-authored-by: Quentin GallouĂ©dec <[email protected]>
3 people authored Jan 21, 2025
1 parent a5c88d6 commit d4222a1
Showing 8 changed files with 97 additions and 8 deletions.
38 changes: 38 additions & 0 deletions docs/source/reducing_memory_usage.md
@@ -93,3 +93,41 @@ training_args = SFTConfig(..., packing=True, max_seq_length=512)
Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).

</Tip>

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).

If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:

<hfoptions id="ds3_gather_for_generation">
<hfoption id="Online DPO">

```python
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="PPO">

```python
from trl import PPOConfig

training_args = PPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
</hfoptions>

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
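
For a fuller picture of where the new option fits, here is a minimal end-to-end sketch. It assumes the standard `OnlineDPOTrainer` entry points; the model, dataset, and judge choices are purely illustrative, and only the `ds3_gather_for_generation=False` line comes from this change:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge

# Illustrative choices; any causal LM, prompt dataset, and judge/reward model could be used.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(
    output_dir="Qwen2-0.5B-OnlineDPO",
    ds3_gather_for_generation=False,  # keep ZeRO-3 shards in place during generation
)
trainer = OnlineDPOTrainer(
    model=model,
    judge=judge,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```

The script would then be launched with `accelerate launch` using a DeepSpeed ZeRO-3 accelerate config; without ZeRO-3, the flag has no effect, since it is only consulted in the ZeRO-3 branch of `unwrap_model_for_generation`.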
7 changes: 6 additions & 1 deletion trl/models/utils.py
@@ -172,7 +172,10 @@ def add_hooks(model: "DeepSpeedEngine") -> None:

 @contextmanager
 def unwrap_model_for_generation(
-    model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
+    model: Union["DistributedDataParallel", "DeepSpeedEngine"],
+    accelerator: "Accelerator",
+    is_peft_model: bool = False,
+    gather_deepspeed3_params: bool = True,
 ) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
     """Context manager to unwrap a model for generation.
     For ZeRO-3 models, we gather the weights once to speed up generation.
@@ -181,6 +184,8 @@ def unwrap_model_for_generation(
     if is_peft_model:
         unwrapped_model.pretrained_model.disable_adapter()
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
+        if not gather_deepspeed3_params:
+            yield accelerator.unwrap_model(model)
         with deepspeed.zero.GatheredParameters(model.parameters()):
             remove_hooks(model)
             yield accelerator.unwrap_model(model)
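
For readers skimming the diff, here is an illustrative sketch (not part of the commit) of how the trainer call sites below thread the new config flag through to this context manager. The helper name and its arguments are stand-ins for the trainers' own attributes:

```python
from typing import Any

import torch
from accelerate import Accelerator

from trl.models.utils import unwrap_model_for_generation


def generate_with_optional_gather(
    model: Any,
    accelerator: Accelerator,
    prompt_ids: torch.Tensor,
    prompt_mask: torch.Tensor,
    ds3_gather_for_generation: bool = True,
) -> torch.Tensor:
    """Hypothetical helper mirroring the patched trainer call sites below."""
    # The config flag is forwarded as `gather_deepspeed3_params`. When False under
    # ZeRO-3, weights stay sharded during generation: slower, but it avoids the OOM
    # that a full parameter gather can trigger for very large models.
    with unwrap_model_for_generation(
        model, accelerator, gather_deepspeed3_params=ds3_gather_for_generation
    ) as unwrapped_model:
        return unwrapped_model.generate(input_ids=prompt_ids, attention_mask=prompt_mask)
```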
16 changes: 14 additions & 2 deletions trl/trainer/online_dpo_config.py
@@ -64,6 +64,10 @@ class OnlineDPOConfig(TrainingArguments):
             Whether to disable dropout in the model and reference model.
         use_vllm (`bool`, *optional*, defaults to `False`):
             Whether to use the vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
+        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+            improving generation speed. However, disabling this option allows training models that exceed the VRAM
+            capacity of a single GPU, albeit at the cost of slower generation.
     """

     learning_rate: float = field(
@@ -114,8 +118,8 @@
         metadata={
             "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
             "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by "
-            "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is "
-            "selected for each new epoch and the last β is used for the rest of the epochs."
+            "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β "
+            "is selected for each new epoch and the last β is used for the rest of the epochs."
         },
     )
     loss_type: str = field(
@@ -140,6 +144,14 @@
"(`pip install vllm`)."
},
)
ds3_gather_for_generation: bool = field(
default=True,
metadata={
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
"generation, improving generation speed. However, disabling this option allows training models that "
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
},
)

def __post_init__(self):
super().__post_init__()
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
@@ -476,7 +476,9 @@ def _generate(self, model, prompts):
         inputs = self._prepare_inputs(inputs)
         prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
         prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+        with unwrap_model_for_generation(
+            model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
             output = unwrapped_model.generate(
                 input_ids=prompt_ids,
                 attention_mask=prompt_mask,
12 changes: 12 additions & 0 deletions trl/trainer/ppo_config.py
@@ -53,6 +53,10 @@ class PPOConfig(OnPolicyConfig):
             Discount factor.
         lam (`float`, *optional*, defaults to `0.95`):
             Lambda value for GAE.
+        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+            improving generation speed. However, disabling this option allows training models that exceed the VRAM
+            capacity of a single GPU, albeit at the cost of slower generation.
     """

     exp_name: str = field(
@@ -103,3 +107,11 @@
         default=0.95,
         metadata={"help": "Lambda value for GAE."},
     )
+    ds3_gather_for_generation: bool = field(
+        default=True,
+        metadata={
+            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
+            "generation, improving generation speed. However, disabling this option allows training models that "
+            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
+        },
+    )
8 changes: 6 additions & 2 deletions trl/trainer/ppo_trainer.py
@@ -414,7 +414,9 @@ def repeat_generator():
                 scores = []
                 sequence_lengths = []
                 values = []
-                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+                with unwrap_model_for_generation(
+                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                ) as unwrapped_model:
                     query_responses, logitss = batch_generation(
                         unwrapped_model.policy,
                         queries,
@@ -688,7 +690,9 @@ def generate_completions(self, sampling: bool = False):
         )

         table = defaultdict(list)
-        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+        with unwrap_model_for_generation(
+            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
             for batch in self.eval_dataloader:
                 query = batch["input_ids"]
                 with torch.no_grad():
12 changes: 12 additions & 0 deletions trl/trainer/rloo_config.py
@@ -50,6 +50,10 @@ class RLOOConfig(OnPolicyConfig):
             Whether to normalize advantages.
         token_level_kl (`bool`, *optional*, defaults to `True`):
             Whether to use token-level KL penalty or sequence-level KL penalty.
+        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+            improving generation speed. However, disabling this option allows training models that exceed the VRAM
+            capacity of a single GPU, albeit at the cost of slower generation.
     """

     exp_name: str = field(
@@ -96,3 +100,11 @@
         default=False,
         metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
     )
+    ds3_gather_for_generation: bool = field(
+        default=True,
+        metadata={
+            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
+            "generation, improving generation speed. However, disabling this option allows training models that "
+            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
+        },
+    )
8 changes: 6 additions & 2 deletions trl/trainer/rloo_trainer.py
@@ -310,7 +310,9 @@ def repeat_generator():
                 sequence_lengths = []

                 # Generate responses and compute logprobs
-                with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
+                with unwrap_model_for_generation(
+                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+                ) as unwrapped_model:
                     query_responses, logitss = batch_generation(
                         unwrapped_model,
                         queries,
@@ -565,7 +567,9 @@ def generate_completions(self, sampling: bool = False):
         )

         table = defaultdict(list)
-        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+        with unwrap_model_for_generation(
+            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+        ) as unwrapped_model:
             for batch in self.eval_dataloader:
                 query = batch["input_ids"]
                 with torch.no_grad():
