
Commit

fix
michP247 committed Jan 6, 2025
1 parent 8cee727 commit d0eba37
Showing 3 changed files with 19 additions and 13 deletions.
flux_train.py (4 changes: 3 additions & 1 deletion)
@@ -481,15 +481,17 @@ def train(args):
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
+
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
+
                             return grad_hook
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
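For context, the hook pattern kept above fuses the optimizer step into the backward pass: once a parameter's gradient has finished accumulating, the hook clips it, steps just that parameter via optimizer.step_param, and drops the gradient to save memory. Below is a minimal, self-contained sketch of the same idea; sgd_step_param is a hypothetical stand-in for the patched Adafactor's step_param, which is specific to this repo.

import torch

# Toy module; the hook wiring, not the model, is the point here.
# Requires PyTorch 2.1+ for register_post_accumulate_grad_hook.
model = torch.nn.Linear(4, 2)
max_grad_norm = 1.0


def sgd_step_param(param: torch.Tensor, lr: float = 1e-2) -> None:
    # Hypothetical stand-in for optimizer.step_param(tensor, param_group):
    # apply an update to a single parameter from its accumulated gradient.
    with torch.no_grad():
        param -= lr * param.grad


for parameter in model.parameters():
    if parameter.requires_grad:

        def grad_hook(tensor: torch.Tensor) -> None:
            # Called by autograd after tensor.grad has been fully accumulated.
            if max_grad_norm != 0.0:
                torch.nn.utils.clip_grad_norm_(tensor, max_grad_norm)
            sgd_step_param(tensor)
            tensor.grad = None  # free the gradient right away

        parameter.register_post_accumulate_grad_hook(grad_hook)

# A single backward() now also applies the per-parameter updates.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()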
sd3_train.py (6 changes: 5 additions & 1 deletion)
@@ -602,18 +602,22 @@ def train(args):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
+
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
+
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
+
                             return grad_hook
+
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
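The other half of the change, applied in all three files, is matching the optimizer type against the fully qualified name "prodigyplus.ProdigyPlusScheduleFree" instead of the bare class name. A dotted "module.ClassName" spec like this can be resolved dynamically; the sketch below is a generic illustration of that idea, not necessarily how this repo's library.train_util loads custom optimizers.

import importlib


def resolve_optimizer_class(optimizer_type: str):
    # Split "prodigyplus.ProdigyPlusScheduleFree" into module and class parts,
    # import the module, and return the class object.
    module_name, _, class_name = optimizer_type.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {optimizer_type!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Usage (assumes the package providing the prodigyplus module is installed):
# optimizer_class = resolve_optimizer_class("prodigyplus.ProdigyPlusScheduleFree")
# optimizer = optimizer_class(model.parameters())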
sdxl_train.py (22 changes: 11 additions & 11 deletions)
@@ -525,18 +525,18 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
-            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
-                for parameter, param_name in zip(param_group["params"], param_name_group):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
                     if parameter.requires_grad:
-                        def create_grad_hook(p_name, p_group):
-                            def grad_hook(tensor: torch.Tensor):
-                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                    accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                optimizer.step_param(tensor, p_group)
-                                tensor.grad = None
-                            return grad_hook
-                        parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+
+                        def __grad_hook(tensor: torch.Tensor, param_group=param_group):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, param_group)
+                            tensor.grad = None
+
+                        parameter.register_post_accumulate_grad_hook(__grad_hook)
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
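The sdxl_train.py hunk also swaps the create_grad_hook factory for a __grad_hook that binds param_group through a default argument. Both forms exist to dodge the same pitfall: Python closures capture loop variables by reference, so a hook that read param_group directly would only ever see the group from the final loop iteration. A tiny illustration of the difference, using hypothetical group labels:

# Late binding: every closure sees the loop variable's final value.
hooks_late = [lambda: group for group in ("unet", "text_encoder")]
print([h() for h in hooks_late])   # ['text_encoder', 'text_encoder']

# Default-argument binding (the __grad_hook pattern): the value is captured
# at definition time, so each hook keeps its own group.
hooks_bound = [lambda group=group: group for group in ("unet", "text_encoder")]
print([h() for h in hooks_bound])  # ['unet', 'text_encoder']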
