
Commit

fixing the logic
HamidShojanazeri committed Oct 30, 2023
1 parent ea5d0d4 commit c63428d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/llama_recipes/finetuning.py
@@ -160,7 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            use_orig_params = True if optimizer_in_backward_available else False,
+            use_orig_params = optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
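
As an aside on this first hunk: `True if optimizer_in_backward_available else False` is just the boolean itself, and the same flag also decides whether FSDP keeps the original parameters exposed, which the in-backward optimizers need. A hedged sketch of how such an availability flag might be derived (the repo's actual definition is not shown in this commit; the try/except below is an assumption):

# Hypothetical availability check, not the repo's code: the overlap relies on
# PyTorch's private _apply_optimizer_in_backward API being importable.
try:
    from torch.distributed.optim import _apply_optimizer_in_backward  # noqa: F401
    optimizer_in_backward_available = True
except ImportError:
    optimizer_in_backward_available = False

# Passing the boolean straight through is equivalent to the old expression;
# use_orig_params=True keeps per-parameter access available under FSDP for the
# in-backward optimizers when the overlap is active.
use_orig_params = optimizer_in_backward_available

The second hunk, below, wires that same flag into the optimizer setup.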
@@ -217,37 +217,37 @@ def main(**kwargs):
     )
 
     # Initialize the optimizer and learning rate scheduler
+    optim_kwargs = {"lr": train_config.lr, "weight_decay":train_config.weight_decay}
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         if optimizer_in_backward_available:
             print(f"setting up optimizer overlap")
             _apply_optimizer_in_backward(
                 optimizer_class=AnyPrecisionAdamW,
                 params=model.parameters(),
+                optimizer_kwargs = optim_kwargs,
                 register_hook=False,
             )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
-            lr=train_config.lr,
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     else:
         if optimizer_in_backward_available:
             print(f"setting up optimizer overlap")
             _apply_optimizer_in_backward(
                 optimizer_class=optim.AdamW,
                 params=model.parameters(),
-                lr=train_config.lr,
+                optimizer_kwargs = optim_kwargs,
                 register_hook=False,
             )
             for p in model.parameters():
                 assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
 
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
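
This second hunk routes the learning rate and weight decay through a single optim_kwargs dict, passed as optimizer_kwargs to _apply_optimizer_in_backward and splatted into the regular optimizers, so both paths see the same settings. Below is a minimal, self-contained sketch of the pattern; it is an illustration rather than the repo's code, and the toy nn.Linear model and hyperparameter values are assumptions. The diff passes register_hook=False, presumably so FSDP can drive the updates from its own backward hooks; the sketch keeps the default register_hook=True so a plain backward() applies the updates on its own.

import torch
from torch import nn, optim
from torch.distributed.optim import _apply_optimizer_in_backward

# Toy model and hyperparameters (assumptions for illustration only).
model = nn.Linear(16, 4)
optim_kwargs = {"lr": 1e-4, "weight_decay": 0.0}

# Attach one AdamW instance per parameter; each steps as its gradient is
# accumulated during backward, overlapping the update with backprop.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs=optim_kwargs,
)

# Each parameter now carries its in-backward optimizer, the same property
# the assertion in the diff checks.
for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")

# A single backward pass both computes gradients and applies the updates;
# there is no separate optimizer.step() call.
loss = model(torch.randn(2, 16)).sum()
loss.backward()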
