From c63428da3a6dc3a0994c675da010aa32571ead56 Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri
Date: Mon, 30 Oct 2023 04:43:50 +0000
Subject: [PATCH] fixing the logic

---
 src/llama_recipes/finetuning.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 02d2a70a3..84d8bf793 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -160,7 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            use_orig_params = True if optimizer_in_backward_available else False,
+            use_orig_params = optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
@@ -217,21 +217,22 @@ def main(**kwargs):
         )
 
     # Initialize the optimizer and learning rate scheduler
+    optim_kwargs = {"lr": train_config.lr, "weight_decay":train_config.weight_decay}
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         if optimizer_in_backward_available:
             print(f"setting up optimizer overlap")
             _apply_optimizer_in_backward(
                 optimizer_class=AnyPrecisionAdamW,
                 params=model.parameters(),
+                optimizer_kwargs = optim_kwargs,
                 register_hook=False,
             )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
-            lr=train_config.lr,
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     else:
         if optimizer_in_backward_available:
@@ -239,15 +240,14 @@ def main(**kwargs):
             _apply_optimizer_in_backward(
                 optimizer_class=optim.AdamW,
                 params=model.parameters(),
-                lr=train_config.lr,
+                optimizer_kwargs = optim_kwargs,
                 register_hook=False,
             )
             for p in model.parameters():
                 assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
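
Note: the change above consolidates the optimizer hyperparameters into a single optim_kwargs dict that is passed both to _apply_optimizer_in_backward (via its optimizer_kwargs argument) and to the regular optimizer constructor, so the two paths cannot drift apart. Below is a minimal standalone sketch of that pattern, assuming a toy nn.Linear model in place of the FSDP-wrapped Llama model and placeholder hyperparameter values in place of train_config; it also assumes a PyTorch build that ships the prototype _apply_optimizer_in_backward API used by finetuning.py.

from torch import nn, optim
from torch.distributed.optim import _apply_optimizer_in_backward

# Toy stand-in for the FSDP-wrapped model (assumption, for illustration only).
model = nn.Linear(16, 16)

# Single source of truth for optimizer hyperparameters; values are placeholders.
optim_kwargs = {"lr": 1e-4, "weight_decay": 0.0}

# Attach a per-parameter AdamW that steps during the backward pass,
# fed with the same kwargs dict as the wrapping optimizer below.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs=optim_kwargs,
    register_hook=False,
)

# Every parameter now carries its in-backward optimizer, mirroring the
# assertion in the patched finetuning.py.
for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")

# The regular optimizer reuses the identical kwargs via unpacking.
optimizer = optim.AdamW(model.parameters(), **optim_kwargs)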