diff --git a/README.md b/README.md
index eb11a2ea8..03296c494 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
 torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
+## FSDP optimizer overlap
+
+Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step in the backward pass, so each gradient can be freed as soon as its parameter has been updated. This is a new FSDP feature, available in PyTorch 2.1.0 and later (or in nightly builds released before 2.1.0).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
+```
+
 ### Fine-tuning using FSDP Only
 
 If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
diff --git a/docs/multi_gpu.md b/docs/multi_gpu.md
index 0e961bb3e..5e02201ce 100644
--- a/docs/multi_gpu.md
+++ b/docs/multi_gpu.md
@@ -46,6 +46,14 @@ Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memor
 torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
+## FSDP optimizer overlap
+
+Setting `optimizer_overlap` in [fsdp_config](../src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step in the backward pass, so each gradient can be freed as soon as its parameter has been updated. This is a new FSDP feature, available in PyTorch 2.1.0 and later (or in nightly builds released before 2.1.0).
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 4 examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap
+```
+
 ### Fine-tuning using FSDP Only
 
 If interested in running full parameter finetuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.
diff --git a/src/llama_recipes/configs/fsdp.py b/src/llama_recipes/configs/fsdp.py
index c89aff1f9..25a58d2b7 100644
--- a/src/llama_recipes/configs/fsdp.py
+++ b/src/llama_recipes/configs/fsdp.py
@@ -16,4 +16,4 @@ class fsdp_config:
     fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
-
+    optimizer_overlap: bool = False
\ No newline at end of file
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 2ec5c2340..dfe037d41 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -67,7 +67,16 @@ def main(**kwargs):
         torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
-
+
+    # Import _apply_optimizer_in_backward for FSDP optimizer overlap
+    optimizer_in_backward_available = False
+    if fsdp_config.optimizer_overlap:
+        try:
+            from torch.distributed.optim import _apply_optimizer_in_backward
+            optimizer_in_backward_available = True
+        except ImportError:
+            print("Optimizer overlap requires _apply_optimizer_in_backward from torch.distributed.optim, which is not available in this PyTorch version; skipping optimizer overlap.")
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -151,6 +160,7 @@
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
+            use_orig_params=optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
@@ -207,21 +217,39 @@
         )
 
     # Initialize the optimizer and learning rate scheduler
+    optim_kwargs = {"lr": train_config.lr, "weight_decay": train_config.weight_decay}
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+        if optimizer_in_backward_available:
+            print("Setting up optimizer overlap.")
+            _apply_optimizer_in_backward(
+                optimizer_class=AnyPrecisionAdamW,
+                params=model.parameters(),
+                optimizer_kwargs=optim_kwargs,
+                register_hook=False,
+            )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
-            lr=train_config.lr,
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     else:
+        if optimizer_in_backward_available:
+            print("Setting up optimizer overlap.")
+            _apply_optimizer_in_backward(
+                optimizer_class=optim.AdamW,
+                params=model.parameters(),
+                optimizer_kwargs=optim_kwargs,
+                register_hook=False,
+            )
+            for p in model.parameters():
+                assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
+
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
     # Start the training process
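For context on what the patch relies on: `_apply_optimizer_in_backward` is a private PyTorch helper that attaches a per-parameter optimizer which steps while `backward()` is still running, so gradients can be released early. The sketch below is not part of the patch; the toy model, shapes, and hyperparameters are made up for illustration, and it shows the helper in isolation rather than the FSDP-integrated path (where the patch passes `register_hook=False` and relies on FSDP with `use_orig_params=True` to drive the per-parameter step).

```python
# Minimal standalone sketch of optimizer-in-backward (assumes a PyTorch build
# that ships torch.distributed.optim._apply_optimizer_in_backward, e.g. 2.1+).
# The toy model and hyperparameters are illustrative only.
import torch
from torch import nn, optim
from torch.distributed.optim import _apply_optimizer_in_backward

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Attach a per-parameter AdamW that runs during backward. With the default
# register_hook=True, each parameter is updated (and its gradient dropped) as
# soon as that gradient has been accumulated, instead of waiting for a global
# optimizer.step() after backward finishes.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-4, "weight_decay": 0.0},
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()  # parameter updates happen here; no separate optimizer.step()

# The helper tags every parameter it handles, which is what the patch asserts on.
assert all(hasattr(p, "_in_backward_optimizers") for p in model.parameters())
```

This is also why the finetuning change never calls `scheduler.step()`-style logic on the in-backward optimizers themselves: the fused step is the source of the memory saving, since per-parameter gradients no longer need to be held until the end of the backward pass.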