diff --git a/README.md b/README.md
index 118096fb0..1fd28fbaa 100644
--- a/README.md
+++ b/README.md
@@ -630,6 +630,14 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
 * 2023/11/01 (v22.2.0)
 - Merge latest sd-scripts dev branch
+  - `sdxl_train.py` now supports a different learning rate for each Text Encoder.
+    - Example:
+      - `--learning_rate 1e-6`: train U-Net only
+      - `--train_text_encoder --learning_rate 1e-6`: train U-Net and the two Text Encoders with the same learning rate (same as the previous version)
+      - `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and the two Text Encoders with different learning rates
+      - `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train the two Text Encoders only
+      - `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and Text Encoder 1 only
+      - `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train Text Encoder 2 only
 * 2023/10/10 (v22.1.0)
 - Remove support for torch 1 to align with kohya_ss sd-scripts code base.
 - Add Intel ARC GPU support with IPEX support on Linux / WSL
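For clarity, the flag combinations above boil down to building optimizer parameter groups and simply omitting any module whose effective learning rate is 0. The following is a minimal illustrative sketch, not code from this patch; `build_param_groups`, the `Namespace` fields, and the `Linear` stand-ins are all hypothetical:

```python
from argparse import Namespace
import torch

def build_param_groups(args, unet, text_encoder1, text_encoder2):
    # per-TE learning rates default to the base learning rate; a rate of 0 disables training
    lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate
    lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate
    groups = []
    if args.learning_rate > 0:
        groups.append({"params": list(unet.parameters()), "lr": args.learning_rate})
    if args.train_text_encoder and lr_te1 > 0:
        groups.append({"params": list(text_encoder1.parameters()), "lr": lr_te1})
    if args.train_text_encoder and lr_te2 > 0:
        groups.append({"params": list(text_encoder2.parameters()), "lr": lr_te2})
    return groups  # pass to e.g. torch.optim.AdamW(groups)

# "--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6"
args = Namespace(train_text_encoder=True, learning_rate=0.0, learning_rate_te1=0.0, learning_rate_te2=1e-6)
models = [torch.nn.Linear(2, 2) for _ in range(3)]  # stand-ins for U-Net and the two Text Encoders
print(len(build_param_groups(args, *models)))  # 1 -> only Text Encoder 2 is trained
```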
diff --git a/fine_tune.py b/fine_tune.py
index 2ecb4ff36..a86a483a0 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -10,10 +10,13 @@
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -32,6 +35,7 @@
     get_weighted_text_embeddings,
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
 )
@@ -192,14 +196,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     for m in training_models:
         m.requires_grad_(True)

-    params = []
-    for m in training_models:
-        params.extend(m.parameters())
-    params_to_optimize = params
+
+    trainable_params = []
+    if args.learning_rate_te is None or not args.train_text_encoder:
+        for m in training_models:
+            trainable_params.extend(m.parameters())
+    else:
+        trainable_params = [
+            {"params": list(unet.parameters()), "lr": args.learning_rate},
+            {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+        ]

     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

     # prepare the dataloader
     # number of DataLoader workers: 0 means the main process
@@ -288,6 +298,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -295,7 +306,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         for m in training_models:
             m.train()

-        loss_total = 0
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(training_models[0]):  # accelerate does not seem to support multiple models here, so do it this way for now
@@ -339,7 +349,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 else:
                     target = noise

-                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
@@ -348,6 +358,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
                     if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                    if args.debiased_estimation_loss:
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                     loss = loss.mean()  # mean over batch dimension
                 else:
@@ -405,17 +417,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 )
                 accelerator.log(logs, step=global_step)

-            # TODO use a moving average
-            loss_total += current_loss
-            avr_loss = loss_total / (step + 1)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(train_dataloader)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
@@ -474,6 +485,12 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )

     return parser
diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py
index ce6e66955..b3c5cc423 100644
--- a/finetune/make_captions_by_git.py
+++ b/finetune/make_captions_by_git.py
@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):

 def main(args):
+    r"""
+    transformers 4.30.2 handles batch size > 1 correctly, so the patch below is commented out
+
     # patch GIT so it works even when the batch size is larger than 1: for transformers 4.26.0
     org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
     curr_batch_size = [args.batch_size]  # the last batch of the loop can be smaller than batch_size, so keep it swappable
@@ -65,6 +68,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
         return input_ids

     GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
+    """

     print(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
@@ -81,7 +85,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
     def run_batch(path_imgs):
         imgs = [im for _, im in path_imgs]

-        curr_batch_size[0] = len(path_imgs)
+        # curr_batch_size[0] = len(path_imgs)
         inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE)  # images are in PIL format
         generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
         captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index af08c5375..1bccb1d3b 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -215,7 +215,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)", ) parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") parser.add_argument( "--bucket_reso_steps", type=int, diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 0ec683a23..a596a0494 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -65,10 +65,13 @@ import diffusers import numpy as np import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -954,7 +957,7 @@ def __call__( text_emb_last = torch.stack(text_emb_last) else: text_emb_last = text_embeddings - + for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) @@ -2363,12 +2366,19 @@ def __getattr__(self, item): network_default_muls = [] network_pre_calc = args.network_pre_calc + # merge関連の引数を統合する + if args.network_merge: + network_merge = len(args.network_module) # all networks are merged + elif args.network_merge_n_models: + network_merge = args.network_merge_n_models + else: + network_merge = 0 + for i, network_module in enumerate(args.network_module): print("import network module:", network_module) imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -2379,31 +2389,32 @@ def __getattr__(self, item): key, value = net_arg.split("=") net_kwargs[key] = value - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) + if args.network_weights is None or len(args.network_weights) <= i: + raise ValueError("No weight. Weight is required.") - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs - ) - else: - raise ValueError("No weight. Weight is required.") + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. 
ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to(text_encoder, unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -2417,6 +2428,7 @@ def __getattr__(self, item): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to(text_encoder, unet, weights_sd, dtype, device) @@ -2712,9 +2724,18 @@ def resize_images(imgs, size): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -3367,10 +3388,19 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str, diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b779ed1e5..0cc444d55 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los loss = loss + loss / scale * v_pred_like_loss return loss +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1/torch.sqrt(snr_t) + loss = weight * loss + return loss # TODO train_utilと分散しているのでどちらかに寄せる @@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted default=None, help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/library/train_util.py b/library/train_util.py index 51610e700..0f5033413 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3,6 +3,7 @@ import argparse import ast import asyncio +import 
diff --git a/library/train_util.py b/library/train_util.py
index 51610e700..0f5033413 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+import datetime
 import importlib
 import json
 import pathlib
@@ -18,7 +19,7 @@
     Tuple,
     Union,
 )
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
 import gc
 import glob
 import math
@@ -148,6 +149,13 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,

 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
+        if max_size is not None:
+            if max_reso is not None:
+                assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
+                assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
+            if min_size is not None:
+                assert max_size >= min_size, "the max_size should be larger than the min_size"
+
         self.no_upscale = no_upscale
         if max_reso is None:
             self.max_reso = None
@@ -2649,7 +2657,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
     )

     # backward compatibility
@@ -2855,6 +2863,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3359,7 +3370,7 @@ def task():


 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
@@ -3463,6 +3474,20 @@ def get_optimizer(args, trainable_params):

         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type == "PagedAdamW32bit".lower():
+        print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+        try:
+            optimizer_class = bnb.optim.PagedAdamW32bit
+        except AttributeError:
+            raise AttributeError(
+                "No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     elif optimizer_type == "SGDNesterov".lower():
         print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
         if "momentum" not in optimizer_kwargs:
@@ -3786,6 +3811,7 @@ def prepare_accelerator(args: argparse.Namespace):
         mixed_precision=args.mixed_precision,
         log_with=log_with,
         project_dir=logging_dir,
+        kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
     )
     return accelerator
@@ -4685,3 +4711,21 @@ def __call__(self, examples):
             dataset.set_current_epoch(self.current_epoch.value)
             dataset.set_current_step(self.current_step.value)
         return examples[0]
+
+
+class LossRecorder:
+    def __init__(self):
+        self.loss_list: List[float] = []
+        self.loss_total: float = 0.0
+
+    def add(self, *, epoch: int, step: int, loss: float) -> None:
+        if epoch == 0:
+            self.loss_list.append(loss)
+        else:
+            self.loss_total -= self.loss_list[step]
+            self.loss_list[step] = loss
+        self.loss_total += loss
+
+    @property
+    def moving_average(self) -> float:
+        return self.loss_total / len(self.loss_list)
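A quick usage sketch (with invented loss values) of the new `LossRecorder`: during epoch 0 it appends, and in later epochs it overwrites the value stored for the same step, so `moving_average` is the mean over the most recent epoch's worth of steps. This assumes the sd-scripts repository root is on `PYTHONPATH` so the class added above is importable:

```python
from library.train_util import LossRecorder

recorder = LossRecorder()
for step, loss in enumerate([1.0, 2.0, 3.0]):  # epoch 0: the list grows
    recorder.add(epoch=0, step=step, loss=loss)
print(recorder.moving_average)  # (1.0 + 2.0 + 3.0) / 3 = 2.0

recorder.add(epoch=1, step=0, loss=0.5)  # epoch 1: replaces the loss stored for step 0
print(recorder.moving_average)  # (0.5 + 2.0 + 3.0) / 3 = 1.8333...
```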
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index ab2b6b3d6..c31ae0072 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -17,10 +17,13 @@
 import diffusers
 import numpy as np
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -1534,12 +1537,20 @@ def __getattr__(self, item):
         network_default_muls = []
         network_pre_calc = args.network_pre_calc

+        # consolidate the merge-related arguments
+        if args.network_merge:
+            network_merge = len(args.network_module)  # all networks are merged
+        elif args.network_merge_n_models:
+            network_merge = args.network_merge_n_models
+        else:
+            network_merge = 0
+        print(f"network_merge: {network_merge}")
+
         for i, network_module in enumerate(args.network_module):
             print("import network module:", network_module)
             imported_module = importlib.import_module(network_module)

             network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
-            network_default_muls.append(network_mul)

             net_kwargs = {}
             if args.network_args and i < len(args.network_args):
@@ -1550,31 +1561,32 @@ def __getattr__(self, item):
                     key, value = net_arg.split("=")
                     net_kwargs[key] = value

-            if args.network_weights and i < len(args.network_weights):
-                network_weight = args.network_weights[i]
-                print("load network weights from:", network_weight)
+            if args.network_weights is None or len(args.network_weights) <= i:
+                raise ValueError("No weight. Weight is required.")

-                if model_util.is_safetensors(network_weight) and args.network_show_meta:
-                    from safetensors.torch import safe_open
+            network_weight = args.network_weights[i]
+            print("load network weights from:", network_weight)

-                    with safe_open(network_weight, framework="pt") as f:
-                        metadata = f.metadata()
-                    if metadata is not None:
-                        print(f"metadata for: {network_weight}: {metadata}")
+            if model_util.is_safetensors(network_weight) and args.network_show_meta:
+                from safetensors.torch import safe_open

-                network, weights_sd = imported_module.create_network_from_weights(
-                    network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
-                )
-            else:
-                raise ValueError("No weight. Weight is required.")
+                with safe_open(network_weight, framework="pt") as f:
+                    metadata = f.metadata()
+                if metadata is not None:
+                    print(f"metadata for: {network_weight}: {metadata}")
+
+            network, weights_sd = imported_module.create_network_from_weights(
+                network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
+            )

             if network is None:
                 return

             mergeable = network.is_mergeable()
-            if args.network_merge and not mergeable:
+            if network_merge and not mergeable:
                 print("network is not mergiable. ignore merge option.")

-            if not args.network_merge or not mergeable:
+            if not mergeable or i >= network_merge:
+                # not merging
                 network.apply_to([text_encoder1, text_encoder2], unet)
                 info = network.load_state_dict(weights_sd, False)  # it would be better to use network.load_weights
                 print(f"weights are loaded: {info}")
@@ -1588,6 +1600,7 @@ def __getattr__(self, item):
                     network.backup_weights()

                 networks.append(network)
+                network_default_muls.append(network_mul)
             else:
                 network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)

@@ -1864,9 +1877,18 @@ def resize_images(imgs, size):
                 size = None

                 for i, network in enumerate(networks):
-                    if i < 3:
+                    if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
                         np_mask = np.array(mask_images[0])
-                        np_mask = np_mask[:, :, i]
+
+                        if args.network_regional_mask_max_color_codes:
+                            # specify masks by color code
+                            ch0 = (i + 1) & 1
+                            ch1 = ((i + 1) >> 1) & 1
+                            ch2 = ((i + 1) >> 2) & 1
+                            np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
+                            np_mask = np_mask.astype(np.uint8) * 255
+                        else:
+                            np_mask = np_mask[:, :, i]
                         size = np_mask.shape
                     else:
                         np_mask = np.full(size, 255, dtype=np.uint8)
@@ -2615,10 +2637,19 @@ def setup_parser() -> argparse.ArgumentParser:
         "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
     )
     parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
+    parser.add_argument(
+        "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
+    )
     parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
     parser.add_argument(
         "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
     )
+    parser.add_argument(
+        "--network_regional_mask_max_color_codes",
+        type=int,
+        default=None,
+        help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
+    )
     parser.add_argument(
         "--textual_inversion_embeddings",
         type=str,
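The consolidation above turns the two merge flags into a single count: `--network_merge` merges every listed network, `--network_merge_n_models N` merges only the first N, and networks with index `i >= network_merge` are applied (hooked) rather than merged. A small sketch of that rule; the function name is illustrative, not from the patch:

```python
def resolve_network_merge(network_merge_flag: bool, n_models, num_networks: int) -> int:
    if network_merge_flag:
        return num_networks  # all networks are merged into the base model
    if n_models:
        return n_models      # only the first N networks are merged
    return 0                 # nothing is merged

network_merge = resolve_network_merge(False, 2, 3)  # e.g. --network_merge_n_models 2 with 3 networks
for i in range(3):
    action = "merge_to" if i < network_merge else "apply_to"
    print(f"network {i}: {action}")  # networks 0 and 1 merged; network 2 applied as a hook
```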
Weight is required.") + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) if network is None: return mergeable = network.is_mergeable() - if args.network_merge and not mergeable: + if network_merge and not mergeable: print("network is not mergiable. ignore merge option.") - if not args.network_merge or not mergeable: + if not mergeable or i >= network_merge: + # not merging network.apply_to([text_encoder1, text_encoder2], unet) info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい print(f"weights are loaded: {info}") @@ -1588,6 +1600,7 @@ def __getattr__(self, item): network.backup_weights() networks.append(network) + network_default_muls.append(network_mul) else: network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) @@ -1864,9 +1877,18 @@ def resize_images(imgs, size): size = None for i, network in enumerate(networks): - if i < 3: + if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes: np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] + + if args.network_regional_mask_max_color_codes: + # カラーコードでマスクを指定する + ch0 = (i + 1) & 1 + ch1 = ((i + 1) >> 1) & 1 + ch2 = ((i + 1) >> 2) & 1 + np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2) + np_mask = np_mask.astype(np.uint8) * 255 + else: + np_mask = np_mask[:, :, i] size = np_mask.shape else: np_mask = np.full(size, 255, dtype=np.uint8) @@ -2615,10 +2637,19 @@ def setup_parser() -> argparse.ArgumentParser: "--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数" ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument( + "--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする" + ) parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") parser.add_argument( "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" ) + parser.add_argument( + "--network_regional_mask_max_color_codes", + type=int, + default=None, + help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)", + ) parser.add_argument( "--textual_inversion_embeddings", type=str, diff --git a/sdxl_train.py b/sdxl_train.py index 7bde3cab7..47bc6a420 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -10,10 +10,13 @@ from tqdm import tqdm import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -34,6 +37,7 @@ prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -271,10 +275,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.wait_for_everyone() # 学習を準備する:モデルを適切な状態にする - training_models = [] if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - 
training_models.append(unet) + train_unet = args.learning_rate > 0 + train_text_encoder1 = False + train_text_encoder2 = False if args.train_text_encoder: # TODO each option for two text encoders? @@ -282,10 +287,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.gradient_checkpointing: text_encoder1.gradient_checkpointing_enable() text_encoder2.gradient_checkpointing_enable() - training_models.append(text_encoder1) - training_models.append(text_encoder2) - # set require_grad=True later + lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train + lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train + train_text_encoder1 = lr_te1 > 0 + train_text_encoder2 = lr_te2 > 0 + + # caching one text encoder output is not supported + if not train_text_encoder1: + text_encoder1.to(weight_dtype) + if not train_text_encoder2: + text_encoder2.to(weight_dtype) + text_encoder1.requires_grad_(train_text_encoder1) + text_encoder2.requires_grad_(train_text_encoder2) + text_encoder1.train(train_text_encoder1) + text_encoder2.train(train_text_encoder2) else: + text_encoder1.to(weight_dtype) + text_encoder2.to(weight_dtype) text_encoder1.requires_grad_(False) text_encoder2.requires_grad_(False) text_encoder1.eval() @@ -294,7 +312,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # TextEncoderの出力をキャッシュする if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad - with torch.no_grad(): + with torch.no_grad(), accelerator.autocast(): train_dataset_group.cache_text_encoder_outputs( (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), @@ -310,30 +328,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): vae.eval() vae.to(accelerator.device, dtype=vae_dtype) - for m in training_models: - m.requires_grad_(True) + unet.requires_grad_(train_unet) + if not train_unet: + unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared - if block_lrs is None: - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params + training_models = [] + params_to_optimize = [] + if train_unet: + training_models.append(unet) + if block_lrs is None: + params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate}) + else: + params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs)) - # calculate number of trainable parameters - n_params = 0 - for p in params: - n_params += p.numel() - else: - params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net - for m in training_models[1:]: # Text Encoders if exists - params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate}) + if train_text_encoder1: + training_models.append(text_encoder1) + params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate}) + if train_text_encoder2: + training_models.append(text_encoder2) + params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate}) - # calculate number of trainable parameters - n_params = 0 - for params in params_to_optimize: - for p in params["params"]: - n_params += p.numel() + # calculate number of trainable parameters + n_params = 0 + for params in params_to_optimize: + for p in params["params"]: + n_params += p.numel() + accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, 
text_encoder2: {train_text_encoder2}") accelerator.print(f"number of models: {len(training_models)}") accelerator.print(f"number of trainable parameters: {n_params}") @@ -385,18 +406,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - if args.train_text_encoder: - unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler - ) - - # transform DDP after prepare - text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet]) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if train_unet: + unet = accelerator.prepare(unet) (unet,) = train_util.transform_models_if_DDP([unet]) - text_encoder1.to(weight_dtype) - text_encoder2.to(weight_dtype) + if train_text_encoder1: + text_encoder1 = accelerator.prepare(text_encoder1) + (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1]) + if train_text_encoder2: + text_encoder2 = accelerator.prepare(text_encoder2) + (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2]) + + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -452,6 +472,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + loss_recorder = train_util.LossRecorder() for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -459,10 +480,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): for m in training_models: m.train() - loss_total = 0 for step, batch in enumerate(train_dataloader): current_step.value = global_step - with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) else: @@ -548,7 +568,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): target = noise - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss: + if ( + args.min_snr_gamma + or args.scale_v_pred_loss_like_noise_pred + or args.v_pred_like_loss + or args.debiased_estimation_loss + ): # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) @@ -559,6 +584,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # mean over batch dimension else: @@ -632,17 +659,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.log(logs, step=global_step) - # TODO moving averageにする - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {"loss": avr_loss} # , "lr": 
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 0df61e848..54abf697c 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -44,6 +44,7 @@
     pyramid_noise_like,
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
 )
 import networks.control_net_lllite_for_train as control_net_lllite_for_train
@@ -350,8 +351,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )

-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group

     # function for saving/removing
@@ -465,6 +465,8 @@ def remove_model(old_ckpt_name):
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                    if args.debiased_estimation_loss:
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size
@@ -500,14 +502,9 @@ def remove_model(old_ckpt_name):
                         remove_model(remove_ckpt_name)

             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if args.logging_dir is not None:
@@ -518,7 +515,7 @@ def remove_model(old_ckpt_name):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 79920a972..f00f10eaa 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -40,6 +40,7 @@
     pyramid_noise_like,
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
 )
 import networks.control_net_lllite as control_net_lllite
@@ -323,8 +324,7 @@ def train(args):
             "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
         )

-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group

     # function for saving/removing
@@ -435,6 +435,8 @@ def remove_model(old_ckpt_name):
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                if args.debiased_estimation_loss:
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size
@@ -470,14 +472,9 @@ def remove_model(old_ckpt_name):
                         remove_model(remove_ckpt_name)

             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if args.logging_dir is not None:
@@ -488,7 +485,7 @@ def remove_model(old_ckpt_name):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 2de57c0ac..199c4e032 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -70,14 +70,16 @@ def cache_text_encoder_outputs_if_needed(
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()

-            dataset.cache_text_encoder_outputs(
-                tokenizers,
-                text_encoders,
-                accelerator.device,
-                weight_dtype,
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-            )
+            # When the TE is not trained, it is not prepared by accelerator, so we need an explicit autocast
+            with accelerator.autocast():
+                dataset.cache_text_encoder_outputs(
+                    tokenizers,
+                    text_encoders,
+                    accelerator.device,
+                    weight_dtype,
+                    args.cache_text_encoder_outputs_to_disk,
+                    accelerator.is_main_process,
+                )

             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
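The explicit `accelerator.autocast()` here covers modules that never went through `accelerator.prepare` and therefore receive no automatic mixed-precision wrapping. A minimal sketch of the pattern, assuming a CUDA machine configured for fp16; the linear layer is a hypothetical stand-in for a text encoder:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # assumption: fp16 mixed precision on GPU
text_encoder = torch.nn.Linear(8, 8).to(accelerator.device)  # NOT passed to accelerator.prepare

tokens = torch.randn(1, 8, device=accelerator.device)
# prepared models get autocast applied by accelerate; unprepared ones need it explicitly
with torch.no_grad(), accelerator.autocast():
    embeddings = text_encoder(tokens)
```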
diff --git a/train_controlnet.py b/train_controlnet.py
index 5bc8d399c..bbd915cb3 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -337,8 +337,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     del train_dataset_group

     # function for saving/removing
@@ -500,14 +499,9 @@ def remove_model(old_ckpt_name):
                     remove_model(remove_ckpt_name)

             current_loss = loss.detach().item()
-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if args.logging_dir is not None:
@@ -518,7 +512,7 @@ def remove_model(old_ckpt_name):
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index a1b9cac8b..fd8e466e5 100644
--- a/train_db.py
+++ b/train_db.py
@@ -11,10 +11,13 @@
 from tqdm import tqdm
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -35,6 +38,7 @@
     pyramid_noise_like,
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
 )

 # perlin_noise,
@@ -163,11 +167,17 @@ def train(args):
     # prepare the classes needed for training
     accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
-        # wightout list, adamw8bit is crashed
-        trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        if args.learning_rate_te is None:
+            # without a list, adamw8bit crashes
+            trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
+        else:
+            trainable_params = [
+                {"params": list(unet.parameters()), "lr": args.learning_rate},
+                {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
+            ]
     else:
         trainable_params = unet.parameters()
-    
+
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)

     # prepare the dataloader
@@ -264,8 +274,7 @@ def train(args):
         init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

-    loss_list = []
-    loss_total = 0.0
+    loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
@@ -336,6 +345,8 @@ def train(args):
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
                     if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                    if args.debiased_estimation_loss:
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size
@@ -392,21 +403,16 @@ def train(args):
                 )
                 accelerator.log(logs, step=global_step)

-            if epoch == 0:
-                loss_list.append(current_loss)
-            else:
-                loss_total -= loss_list[step]
-                loss_list[step] = current_loss
-            loss_total += current_loss
-            avr_loss = loss_total / len(loss_list)
-            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
                 break

         if args.logging_dir is not None:
-            logs = {"loss/epoch": loss_total / len(loss_list)}
+            logs = {"loss/epoch": loss_recorder.moving_average}
             accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
@@ -464,6 +470,12 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)

+    parser.add_argument(
+        "--learning_rate_te",
+        type=float,
+        default=None,
+        help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+    )
     parser.add_argument(
         "--no_token_padding",
         action="store_true",
diff --git a/train_network.py b/train_network.py
index 2232a384a..d50916b74 100644
--- a/train_network.py
+++ b/train_network.py
@@ -43,6 +43,7 @@
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
+    apply_debiased_estimation,
 )
@@ -108,6 +109,9 @@ def load_tokenizer(self, args):
     def is_text_encoder_outputs_cached(self, args):
         return False

+    def is_train_text_encoder(self, args):
+        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+
     def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
@@ -309,7 +313,7 @@ def train(self, args):
             args.scale_weight_norms = False

         train_unet = not args.network_train_text_encoder_only
-        train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+        train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

         if args.network_weights is not None:
@@ -402,6 +406,8 @@ def train(self, args):
             unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, network, optimizer, train_dataloader, lr_scheduler
             )
+            for t_enc in text_encoders:
+                t_enc.to(accelerator.device, dtype=weight_dtype)
         elif train_text_encoder:
             if len(text_encoders) > 1:
                 t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -528,6 +534,7 @@ def train(self, args):
             "ss_min_snr_gamma": args.min_snr_gamma,
             "ss_scale_weight_norms": args.scale_weight_norms,
             "ss_ip_noise_gamma": args.ip_noise_gamma,
+            "ss_debiased_estimation": bool(args.debiased_estimation_loss),
         }

         if use_user_config:
@@ -703,8 +710,7 @@ def train(self, args):
                 "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
             )

-        loss_list = []
-        loss_total = 0.0
+        loss_recorder = train_util.LossRecorder()
         del train_dataset_group

         # callback for step start
@@ -765,7 +771,7 @@ def remove_model(old_ckpt_name):
                         latents = latents * self.vae_scale_factor
                 b_size = latents.shape[0]

-                with torch.set_grad_enabled(train_text_encoder):
+                with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
                         text_encoder_conds = get_weighted_text_embeddings(
@@ -811,6 +817,8 @@ def remove_model(old_ckpt_name):
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                    if args.debiased_estimation_loss:
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size
@@ -854,14 +862,9 @@ def remove_model(old_ckpt_name):
                             remove_model(remove_ckpt_name)

                 current_loss = loss.detach().item()
-                if epoch == 0:
-                    loss_list.append(current_loss)
-                else:
-                    loss_total -= loss_list[step]
-                    loss_list[step] = current_loss
-                loss_total += current_loss
-                avr_loss = loss_total / len(loss_list)
-                logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+                loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                avr_loss: float = loss_recorder.moving_average
+                logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                progress_bar.set_postfix(**logs)

                 if args.scale_weight_norms:
@@ -875,7 +878,7 @@ def remove_model(old_ckpt_name):
                     break

             if args.logging_dir is not None:
-                logs = {"loss/epoch": loss_total / len(loss_list)}
+                logs = {"loss/epoch": loss_recorder.moving_average}
                 accelerator.log(logs, step=epoch + 1)

         accelerator.wait_for_everyone()
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 252add536..6b6e7f5a0 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -32,6 +32,7 @@
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
+    apply_debiased_estimation,
 )

 imagenet_templates_small = [
@@ -582,6 +583,8 @@ def remove_model(old_ckpt_name):
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                if args.debiased_estimation_loss:
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 525e612f1..8dd5c672f 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -34,6 +34,7 @@
     pyramid_noise_like,
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
 )
 import library.original_unet as original_unet
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -471,6 +472,8 @@ def remove_model(old_ckpt_name):
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
                 if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                if args.debiased_estimation_loss:
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                 loss = loss.mean()  # this is already a mean, so no need to divide by batch_size