How to support multi-device VLLM inference in the GRPO Trainer #2922

Open
0x404 opened this issue Feb 21, 2025 · 0 comments
Labels
✨ enhancement New feature or request 🏋 GRPO Related to GRPO

Comments

0x404 commented Feb 21, 2025

if self.accelerator.is_main_process:
    vllm_device = self.args.vllm_device
    if vllm_device == "auto":
        if torch.cuda.device_count() == 1:
            vllm_device = "cuda:0"  # particular case when training with only 1 GPU: share it
        else:
            vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
    # Check that the requested device is available
    if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
        raise ValueError(
            f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
            "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
            "value lower than the number of GPUs available on your machine—typically, reducing it by one "
            f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
        )
    # Check that the requested device is not also used for training
    if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
        warnings.warn(
            f"The requested device {vllm_device} is also being used for training. For higher throughput "
            "and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
            "If this is intentional, you may ignore this warning but should adjust "
            "`vllm_gpu_memory_utilization` accordingly."
        )
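
To make the device-selection rule above easier to reason about, here is a minimal standalone sketch of the same logic with the GPU count and process count passed in as plain integers, so it can be run without any GPU. The function name `resolve_vllm_device` and its parameters are hypothetical and not part of TRL; unlike the snippet above, it also tolerates a device string with no index (such as "cuda").

```python
import warnings


def resolve_vllm_device(requested: str, num_processes: int, device_count: int) -> str:
    """Hypothetical helper mirroring the device choice in the snippet above."""
    device = requested
    if device == "auto":
        # With a single GPU, share it; otherwise take the first GPU index
        # after the ones used by the training processes.
        device = "cuda:0" if device_count == 1 else f"cuda:{num_processes}"

    kind, _, index = device.partition(":")
    # The requested CUDA index must exist on this machine.
    if kind == "cuda" and index.isdigit() and int(index) >= device_count:
        raise ValueError(
            f"Device {device} is not available; try num_processes={device_count - 1}."
        )

    # Warn when generation would share a GPU with a training process.
    if device in {f"cuda:{idx}" for idx in range(num_processes)}:
        warnings.warn(f"Device {device} is also used for training.")
    return device


# 8 GPUs, 7 training processes: generation gets the single remaining GPU.
print(resolve_vllm_device("auto", num_processes=7, device_count=8))
```

With 8 GPUs and 8 training processes the same call raises `ValueError`, which is the case the error message in the snippet describes.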

In the current GRPO implementation, vLLM can only run on a single GPU, which becomes a performance bottleneck. For example, in an 8-GPU setup, the other 7 GPUs have to wait while 1 GPU completes inference. A single GPU also can't accommodate larger models.

How can we enable vLLM to run on multiple GPUs? The only concern is that we need a way to update the parameters on all of those GPUs each time the model weights are reloaded:

def _move_model_to_vllm(self):
    with unwrap_model_for_generation(
        self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
    ) as unwrapped_model:
        if is_compiled_module(unwrapped_model):
            unwrapped_model = unwrapped_model._orig_mod
        if is_peft_model(unwrapped_model):
            unwrapped_model.merge_adapter()
            state_dict = unwrapped_model.state_dict()
            # Remove base_model and base_layer prefixes
            state_dict = {
                k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
            }
            # Remove values with adapter prefix (example: "_lora")
            state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
            # When module to save, remove its prefix and discard the original module
            state_dict = {
                k.replace("modules_to_save.default.", ""): v
                for k, v in state_dict.items()
                if "original_module" not in k
            }
        else:
            state_dict = unwrapped_model.state_dict()
        if self.accelerator.is_main_process:
            llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
            llm_model.load_weights(state_dict.items())
        # Unmerge the adapter to restore the model to its original state.
        # This must be done after loading weights to ensure they correspond to the merged state.
        if is_peft_model(unwrapped_model):
            unwrapped_model.unmerge_adapter()
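
The concern can be illustrated with a toy model in plain Python. This is only a sketch under loud assumptions: `ToyWorker` and `sync_all_workers` are hypothetical stand-ins, not vLLM or TRL classes, and sharding is simplified to contiguous slices of a list. It shows that if new weights reach only the driver worker, as in the function above, any other tensor-parallel worker keeps serving stale parameters until every worker loads its own shard.

```python
from typing import Dict, List


class ToyWorker:
    """Stand-in for one tensor-parallel rank (hypothetical, not a vLLM class)."""

    def __init__(self, rank: int, world_size: int) -> None:
        self.rank = rank
        self.world_size = world_size
        self.weights: Dict[str, List[float]] = {}

    def load_weights(self, state_dict: Dict[str, List[float]]) -> None:
        # Keep only this rank's contiguous slice of every parameter.
        for name, values in state_dict.items():
            shard = len(values) // self.world_size
            start = self.rank * shard
            self.weights[name] = values[start:start + shard]


def sync_all_workers(workers: List[ToyWorker], state_dict: Dict[str, List[float]]) -> None:
    # Every rank must see the update, not just the driver (rank 0).
    for worker in workers:
        worker.load_weights(state_dict)


workers = [ToyWorker(rank, world_size=2) for rank in range(2)]
sync_all_workers(workers, {"w": [0.0, 0.0, 0.0, 0.0]})  # initial weights

new_state = {"w": [1.0, 2.0, 3.0, 4.0]}
workers[0].load_weights(new_state)        # driver-only update
stale = list(workers[1].weights["w"])     # rank 1 still holds the old shard

sync_all_workers(workers, new_state)      # update reaches every rank
print(stale, workers[1].weights["w"])
```

In a real multi-GPU engine the workers would live in separate processes, so the loop in `sync_all_workers` would have to become some form of cross-process broadcast or remote call; which mechanism vLLM should expose for that is exactly the open question here.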

@github-actions github-actions bot added 🏋 GRPO Related to GRPO ✨ enhancement New feature or request labels Feb 21, 2025