Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding optimizer overlap for FSDP #203

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
```

## FSDP optimizer overlap

setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enable optimizer overlap that lowers the GPU memory footprint during the training. The main idea here is the fusion of gradient calculation and parameter update in the one step. This is a new feature in FSDP and is only available from PyTorch 2.1.0 onward.

```bash
torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
```

### Fine-tuning using FSDP Only

If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
Expand Down
8 changes: 8 additions & 0 deletions docs/multi_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
```

## FSDP optimizer overlap

setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enable optimizer overlap that lowers the GPU memory footprint during the training. The main idea here is the fusion of gradient calculation and parameter update in the one step. This is a new feature in FSDP and is only available on PyTorch nightly binaries for versions before 2.1.0.

```bash
torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
```

### Fine-tuning using FSDP Only

If interested in running full parameter finetuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
Expand Down
2 changes: 1 addition & 1 deletion src/llama_recipes/configs/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class fsdp_config:
fsdp_cpu_offload: bool=False
pure_bf16: bool = False
optimizer: str= "AdamW"

optimizer_overlap: bool = False
38 changes: 33 additions & 5 deletions src/llama_recipes/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,16 @@ def main(**kwargs):
torch.cuda.set_device(local_rank)
clear_gpu_cache(local_rank)
setup_environ_flags(rank)


#import _apply_optimizer_in_backward for FSDP optimizer overlap
optimizer_in_backward_available = False
if fsdp_config.optimizer_overlap:
try:
from torch.distributed.optim import _apply_optimizer_in_backward
optimizer_in_backward_available = True
except ImportError:
print("The required module for optimizer overlap in 'torch.distributed.optim' is not available, skipping applying optimizer overlap.")

# Load the pre-trained model and setup its configuration
use_cache = False if train_config.enable_fsdp else None
if train_config.enable_fsdp and train_config.low_cpu_fsdp:
Expand Down Expand Up @@ -151,6 +160,7 @@ def main(**kwargs):
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
sync_module_states=train_config.low_cpu_fsdp,
use_orig_params = optimizer_in_backward_available,
param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
if train_config.low_cpu_fsdp and rank != 0 else None,
)
Expand Down Expand Up @@ -207,21 +217,39 @@ def main(**kwargs):
)

# Initialize the optimizer and learning rate scheduler
optim_kwargs = {"lr": train_config.lr, "weight_decay":train_config.weight_decay}
if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
if optimizer_in_backward_available:
print(f"setting up optimizer overlap")
_apply_optimizer_in_backward(
optimizer_class=AnyPrecisionAdamW,
params=model.parameters(),
optimizer_kwargs = optim_kwargs,
register_hook=False,
)
optimizer = AnyPrecisionAdamW(
model.parameters(),
lr=train_config.lr,
momentum_dtype=torch.bfloat16,
variance_dtype=torch.bfloat16,
use_kahan_summation=False,
weight_decay=train_config.weight_decay,
**optim_kwargs,
)
else:
if optimizer_in_backward_available:
print(f"setting up optimizer overlap")
_apply_optimizer_in_backward(
optimizer_class=optim.AdamW,
params=model.parameters(),
optimizer_kwargs = optim_kwargs,
register_hook=False,
)
for p in model.parameters():
assert hasattr(p, "_in_backward_optimizers")
optimizer = optim.AdamW(
model.parameters(),
lr=train_config.lr,
weight_decay=train_config.weight_decay,
**optim_kwargs,
)

scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

# Start the training process
Expand Down