From d0eba379467be1e5d0eec86b81e8cd0718dd5d6f Mon Sep 17 00:00:00 2001
From: some_ai <94305159+michP247@users.noreply.github.com>
Date: Mon, 6 Jan 2025 01:19:43 -0400
Subject: [PATCH] fix

---
 flux_train.py |  4 +++-
 sd3_train.py  |  6 +++++-
 sdxl_train.py | 22 +++++++++++-----------
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/flux_train.py b/flux_train.py
index ce3b1ca23..7e3597d73 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -481,15 +481,17 @@ def train(args):
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
+
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
+
                             return grad_hook
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
diff --git a/sd3_train.py b/sd3_train.py
index 3d50f13a9..7159ace8b 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -602,18 +602,22 @@ def train(args):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
+
             for param_group, param_name_group in zip(optimizer.param_groups, param_names):
                 for parameter, param_name in zip(param_group["params"], param_name_group):
                     if parameter.requires_grad:
+
                         def create_grad_hook(p_name, p_group):
                             def grad_hook(tensor: torch.Tensor):
                                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                     accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                                 optimizer.step_param(tensor, p_group)
                                 tensor.grad = None
+
                             return grad_hook
+
                         parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
diff --git a/sdxl_train.py b/sdxl_train.py
index 786d3d5de..608c9be1a 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -525,18 +525,18 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         if args.optimizer_type == "AdaFactor":
             import library.adafactor_fused
             library.adafactor_fused.patch_adafactor_fused(optimizer)
-            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
-                for parameter, param_name in zip(param_group["params"], param_name_group):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
                     if parameter.requires_grad:
-                        def create_grad_hook(p_name, p_group):
-                            def grad_hook(tensor: torch.Tensor):
-                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                    accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                                optimizer.step_param(tensor, p_group)
-                                tensor.grad = None
-                            return grad_hook
-                        parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
-        elif args.optimizer_type == "ProdigyPlusScheduleFree":
+
+                        def __grad_hook(tensor: torch.Tensor, param_group=param_group):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                            optimizer.step_param(tensor, param_group)
+                            tensor.grad = None
+
+                        parameter.register_post_accumulate_grad_hook(__grad_hook)
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
             # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
             pass
         else:
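
For reference, a minimal standalone sketch of the hook pattern the sdxl_train.py hunk switches to: each trainable parameter registers a post-accumulate-grad hook that clips its gradient, applies an update, and frees the gradient immediately, so no full optimizer.step() over all parameters is needed. This is not part of the patch; it assumes PyTorch >= 2.1 for Tensor.register_post_accumulate_grad_hook, and the plain SGD update is only a stand-in for the patched AdaFactor optimizer's step_param. The default argument on __grad_hook binds the enclosing loop's param_group at definition time, which is why the create_grad_hook factory is unnecessary.

# Standalone illustration of the per-parameter fused backward pass hook.
# Assumes PyTorch >= 2.1; the manual SGD step below stands in for the
# patched optimizer's step_param(tensor, param_group).
import torch

model = torch.nn.Linear(4, 4)
param_groups = [{"params": list(model.parameters()), "lr": 1e-2}]
max_grad_norm = 1.0

for param_group in param_groups:
    for parameter in param_group["params"]:
        if parameter.requires_grad:

            # The default argument captures this iteration's param_group,
            # so no factory closure is needed to bind the loop variable.
            def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                if max_grad_norm != 0.0:
                    torch.nn.utils.clip_grad_norm_([tensor], max_grad_norm)
                with torch.no_grad():
                    # stand-in for optimizer.step_param(tensor, param_group)
                    tensor.add_(tensor.grad, alpha=-param_group["lr"])
                tensor.grad = None  # free the gradient right away to save memory

            parameter.register_post_accumulate_grad_hook(__grad_hook)

# The hooks fire during backward(), once each parameter's grad has accumulated.
loss = model(torch.randn(2, 4)).sum()
loss.backward()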