support fused_back_pass for prodigy-plus-schedule-free #1867

Open · wants to merge 2 commits into sd3
flux_train.py (42 changes: 24 additions & 18 deletions)
@@ -381,6 +381,8 @@ def train(args):
         optimizer_train_fn = lambda: None  # dummy function
         optimizer_eval_fn = lambda: None  # dummy function
     else:
+        if args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree" and args.fused_backward_pass:
+            args.optimizer_args.append("fused_back_pass=True")
         _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
         optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

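Background for the hunk above: train_util.get_optimizer parses each entry of --optimizer_args as a key=value pair and forwards it as a keyword argument to the optimizer constructor, so appending "fused_back_pass=True" turns on ProdigyPlusScheduleFree's built-in fused backward pass. The sketch below is not from this PR; it only illustrates that kind of parsing, and the ast.literal_eval value handling is an assumption about how the real helper interprets values.

# Illustrative sketch (not the actual train_util implementation): turn
# ["fused_back_pass=True", ...] style strings into optimizer keyword arguments.
import ast

def parse_optimizer_args(optimizer_args):
    kwargs = {}
    for entry in optimizer_args or []:
        key, value = entry.split("=", 1)
        kwargs[key] = ast.literal_eval(value)  # "True" -> True, "1e-6" -> 1e-06, etc.
    return kwargs

print(parse_optimizer_args(["fused_back_pass=True", "d0=1e-6"]))
# {'fused_back_pass': True, 'd0': 1e-06}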
@@ -473,25 +475,29 @@ def train(args):
     train_util.resume_from_local_or_hf_if_specified(accelerator, args)

     if args.fused_backward_pass:
-        # use fused optimizer for backward pass: other optimizers will be supported in the future
-        import library.adafactor_fused
-
-        library.adafactor_fused.patch_adafactor_fused(optimizer)
-
-        for param_group, param_name_group in zip(optimizer.param_groups, param_names):
-            for parameter, param_name in zip(param_group["params"], param_name_group):
-                if parameter.requires_grad:
-
-                    def create_grad_hook(p_name, p_group):
-                        def grad_hook(tensor: torch.Tensor):
-                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
-                            optimizer.step_param(tensor, p_group)
-                            tensor.grad = None
-
-                        return grad_hook
-                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
+        if args.optimizer_type == "AdaFactor":
+            import library.adafactor_fused
+            library.adafactor_fused.patch_adafactor_fused(optimizer)
+            for param_group, param_name_group in zip(optimizer.param_groups, param_names):
+                for parameter, param_name in zip(param_group["params"], param_name_group):
+                    if parameter.requires_grad:
+
+                        def create_grad_hook(p_name, p_group):
+                            def grad_hook(tensor: torch.Tensor):
+                                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                    accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
+                                optimizer.step_param(tensor, p_group)
+                                tensor.grad = None
+
+                            return grad_hook
+                        parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
+        elif args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree":
+            # ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
+            pass
+        else:
+            logger.warning(
+                f"Fused backward pass is not supported for optimizer type: {args.optimizer_type}. Ignoring."
+            )

     elif args.blockwise_fused_optimizers:
         # prepare for additional optimizers and lr schedulers
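For reference, the fused backward pass that the AdaFactor branch implements works by registering a post-accumulate-grad hook on every trainable parameter: the hook steps that single parameter and frees its gradient as soon as the gradient is ready, so peak gradient memory stays close to one tensor instead of the whole model. The sketch below is not from this PR: it uses a plain SGD update in place of the patched AdaFactor's step_param, skips gradient clipping and accelerate, and assumes PyTorch 2.1+ for Tensor.register_post_accumulate_grad_hook.

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
lr = 1e-2

def make_hook(learning_rate):
    @torch.no_grad()
    def hook(param: torch.Tensor):
        # Fires right after param.grad has finished accumulating in backward().
        param.add_(param.grad, alpha=-learning_rate)  # in-place SGD step for this parameter
        param.grad = None  # free the gradient immediately instead of keeping it for step()
    return hook

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(make_hook(lr))

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()  # parameters are updated inside the hooks; no optimizer.step() call needed

ProdigyPlusScheduleFree implements the equivalent internally when constructed with fused_back_pass=True, which is why the elif branch above only needs to pass.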