From 6231aa91e251b9321cb630a4cccb5c96bec49a1f Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 5 Nov 2023 19:09:17 +0900
Subject: [PATCH] common lr logging, set default None to ddp_timeout

---
 fine_tune.py          |  9 ++-------
 library/train_util.py | 35 ++++++++++++++++++++++++++++++++---
 sdxl_train.py         | 37 ++++++++++---------------------------
 train_db.py           |  9 ++-------
 4 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index a86a483a0..52e84c43f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -408,13 +408,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if (
-                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-                ):  # tracking d*lr value
-                    logs["lr/d*lr"] = (
-                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                    )
+                logs = {"loss": current_loss}
+                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
                 accelerator.log(logs, step=global_step)
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
diff --git a/library/train_util.py b/library/train_util.py
index 0f5033413..cc9ac4555 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2864,7 +2864,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
     parser.add_argument(
-        "--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
+        "--ddp_timeout",
+        type=int,
+        default=None,
+        help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
     )
     parser.add_argument(
         "--clip_skip",
@@ -3806,12 +3809,15 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.wandb_api_key is not None:
         wandb.login(key=args.wandb_api_key)
 
+    kwargs_handlers = (
+        None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
+    )
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=log_with,
         project_dir=logging_dir,
-        kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
+        kwargs_handlers=kwargs_handlers,
     )
     return accelerator
 
@@ -4401,6 +4407,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     return noise, noisy_latents, timesteps
 
 
+def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
+    names = []
+    if including_unet:
+        names.append("unet")
+    names.append("text_encoder1")
+    names.append("text_encoder2")
+
+    append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
+
+
+def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
+    lrs = lr_scheduler.get_last_lr()
+
+    for lr_index in range(len(lrs)):
+        name = names[lr_index]
+        logs["lr/" + name] = float(lrs[lr_index])
+
+        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
+            logs["lr/d*lr/" + name] = (
+                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
+            )
+
+
 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
 SCHEDULER_LINEAR_END = 0.0120
@@ -4718,7 +4747,7 @@ def __init__(self):
         self.loss_list: List[float] = []
         self.loss_total: float = 0.0
 
-    def add(self, *, epoch:int, step: int, loss: float) -> None:
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
         if epoch == 0:
             self.loss_list.append(loss)
         else:
diff --git a/sdxl_train.py b/sdxl_train.py
index 47bc6a420..fd775624e 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -74,33 +74,22 @@ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List
 
 
 def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
-    lrs = lr_scheduler.get_last_lr()
-
-    lr_index = 0
+    names = []
     block_index = 0
-    while lr_index < len(lrs):
+    while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
         if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
-            name = f"block{block_index}"
             if block_lrs[block_index] == 0:
                 block_index += 1
                 continue
+            names.append(f"block{block_index}")
         elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
-            name = "text_encoder1"
+            names.append("text_encoder1")
         elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
-            name = "text_encoder2"
-        else:
-            raise ValueError(f"unexpected block_index: {block_index}")
+            names.append("text_encoder2")
 
         block_index += 1
 
-        logs["lr/" + name] = float(lrs[lr_index])
-
-        if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
-            logs["lr/d*lr/" + name] = (
-                lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
-            )
-
-        lr_index += 1
+    train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
 
 
 def train(args):
@@ -287,8 +276,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     if args.gradient_checkpointing:
         text_encoder1.gradient_checkpointing_enable()
         text_encoder2.gradient_checkpointing_enable()
-    lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
-    lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
+    lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
+    lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
     train_text_encoder1 = lr_te1 > 0
     train_text_encoder2 = lr_te2 > 0
 
@@ -647,15 +636,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if args.logging_dir is not None:
                 logs = {"loss": current_loss}
                 if block_lrs is None:
-                    logs["lr"] = float(lr_scheduler.get_last_lr()[0])
-                    if (
-                        args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-                    ):  # tracking d*lr value
-                        logs["lr/d*lr"] = (
-                            lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                        )
+                    train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
                 else:
-                    append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
+                    append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)  # U-Net is included in block_lrs
 
                 accelerator.log(logs, step=global_step)
 
diff --git a/train_db.py b/train_db.py
index fd8e466e5..7fbbc18ac 100644
--- a/train_db.py
+++ b/train_db.py
@@ -394,13 +394,8 @@ def train(args):
 
             current_loss = loss.detach().item()
             if args.logging_dir is not None:
-                logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if (
-                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-                ):  # tracking d*lr value
-                    logs["lr/d*lr"] = (
-                        lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
-                    )
+                logs = {"loss": current_loss}
+                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
                 accelerator.log(logs, step=global_step)
 
             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
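
Usage note (not part of the patch): a minimal sketch of how the new
train_util.append_lr_to_logs() helper can be exercised on its own. It assumes
the sd-scripts repository root is on PYTHONPATH so that library.train_util is
importable; the Linear model and the two optimizer param groups below are
illustrative stand-ins, not names taken from the patch.

    import torch

    from library import train_util  # assumes sd-scripts root is on PYTHONPATH

    # Two param groups standing in for the "unet" and "text_encoder1" learning
    # rates; a third group would be labeled "text_encoder2" by the helper.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(
        [
            {"params": [model.weight], "lr": 1e-4},
            {"params": [model.bias], "lr": 5e-5},
        ]
    )
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1.0)

    logs = {"loss": 0.123}
    # With including_unet=True the helper labels the get_last_lr() entries as
    # "unet", "text_encoder1", "text_encoder2", one per param group, in order.
    train_util.append_lr_to_logs(logs, lr_scheduler, optimizer_type="AdamW", including_unet=True)
    print(logs)  # {'loss': 0.123, 'lr/unet': 0.0001, 'lr/text_encoder1': 5e-05}

For DAdapt*/Prodigy optimizer types the helper additionally records
"lr/d*lr/<name>" from the "d" value in the wrapped optimizer's param groups,
and with --ddp_timeout left at its new default of None, prepare_accelerator()
passes kwargs_handlers=None so accelerate's own process-group timeout applies.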