From 01e00ac1b085c562ccc14a1c166c43ccf90d2a83 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 19:45:44 +0900 Subject: [PATCH 01/12] Make a function get_my_scheduler() --- library/train_util.py | 98 ++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 51610e700..d6a5221a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4381,6 +4381,59 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" +def get_my_scheduler( + *, + sample_sampler: str, + v_parameterization: bool, +): + sched_init_args = {} + if sample_sampler == "ddim": + scheduler_cls = DDIMScheduler + elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + elif sample_sampler == "pndm": + scheduler_cls = PNDMScheduler + elif sample_sampler == "lms" or sample_sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + elif sample_sampler == "euler" or sample_sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = sample_sampler + elif sample_sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + elif sample_sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + else: + scheduler_cls = DDIMScheduler + + if v_parameterization: + sched_init_args["prediction_type"] = "v_prediction" + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # clip_sample=Trueにする + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is False + ): + # print("set clip_sample to True") + scheduler.config.clip_sample = True + + return scheduler + def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) @@ -4438,50 +4491,11 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - # schedulerを用意する - sched_init_args = {} - if args.sample_sampler == "ddim": - scheduler_cls = DDIMScheduler - elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - elif args.sample_sampler == "pndm": - scheduler_cls = PNDMScheduler - elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sample_sampler - elif args.sample_sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - elif args.sample_sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - else: - scheduler_cls = DDIMScheduler - - if args.v_parameterization: - sched_init_args["prediction_type"] = "v_prediction" - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, + scheduler = get_my_scheduler( + sample_sampler=args.scheduler, + v_parameterization=args.v_parameterization, ) - # clip_sample=Trueにする - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - scheduler.config.clip_sample = True - pipeline = pipe_class( text_encoder=text_encoder, vae=vae, From 291c29caaf2f17e4c61b523522d7453df8a1c480 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 19:57:25 +0900 Subject: [PATCH 02/12] Added a function line_to_prompt_dict() and removed duplicated initializations --- library/train_util.py | 126 +++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 64 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d6a5221a7..b93b8ea48 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4439,6 +4439,55 @@ def sample_images(*args, **kwargs): return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs) +def line_to_prompt_dict(line: str) -> dict: + # subset of gen_img_diffusers + prompt_args = line.split(" --") + prompt_dict = {} + prompt_dict['prompt'] = prompt_args[0] + + for parg in prompt_args: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['width'] = int(m.group(1)) + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['height'] = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + prompt_dict['seed'] = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + prompt_dict['scale'] = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['negative_prompt'] = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['controlnet_image'] = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + return prompt_dict + def sample_images_common( pipe_class, accelerator, @@ -4517,73 +4566,22 @@ def sample_images_common( with torch.no_grad(): # with accelerator.autocast(): - for i, prompt in enumerate(prompts): + for i, prompt_dict in enumerate(prompts): if not accelerator.is_main_process: continue - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - controlnet_image = prompt.get("controlnet_image") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue - - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - controlnet_image = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt - controlnet_image = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + + assert isinstance(prompt_dict, dict) + negative_prompt = prompt_dict.get("negative_prompt") + sample_steps = prompt_dict.get("sample_steps", 30) + width = prompt_dict.get("width", 512) + height = prompt_dict.get("height", 512) + scale = prompt_dict.get("scale", 7.5) + seed = prompt_dict.get("seed") + controlnet_image = prompt_dict.get("controlnet_image") + prompt: str = prompt_dict.get("prompt", "") if seed is not None: torch.manual_seed(seed) From cf876fcdb40d46c1bd21d50106ea44ada9f45671 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 20:15:04 +0900 Subject: [PATCH 03/12] Accept --ss to set sample_sampler dynamically --- library/train_util.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b93b8ea48..949a82065 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4477,6 +4477,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict['negative_prompt'] = m.group(1) continue + m = re.match(r"ss (.+)", parg, re.IGNORECASE) + if m: # negative prompt + prompt_dict['sample_sampler'] = m.group(1) + continue + m = re.match(r"cn (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict['controlnet_image'] = m.group(1) @@ -4540,17 +4545,19 @@ def sample_images_common( with open(args.sample_prompts, "r", encoding="utf-8") as f: prompts = json.load(f) - scheduler = get_my_scheduler( - sample_sampler=args.scheduler, + schedulers: dict = {} + default_scheduler = get_my_scheduler( + sample_sampler=args.sample_sampler, v_parameterization=args.v_parameterization, ) + schedulers[args.sample_sampler] = default_scheduler pipeline = pipe_class( text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=scheduler, + scheduler=default_scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, @@ -4582,11 +4589,18 @@ def sample_images_common( seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") + sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler) if seed is not None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + scheduler = schedulers.get(sampler_name) + if scheduler is None: + scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,) + schedulers[sampler_name] = scheduler + pipeline.scheduler = scheduler + if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) if negative_prompt is not None: @@ -4604,6 +4618,7 @@ def sample_images_common( print(f"width: {width}") print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") + print(f"sample_sampler: {sampler_name}") with accelerator.autocast(): latents = pipeline( prompt=prompt, From 40d917b0fecc2459dff8a3848d7ea0c7d6c21ccb Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:02:44 +0900 Subject: [PATCH 04/12] Removed incorrect comments --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 949a82065..edff4f49a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4478,12 +4478,12 @@ def line_to_prompt_dict(line: str) -> dict: continue m = re.match(r"ss (.+)", parg, re.IGNORECASE) - if m: # negative prompt + if m: prompt_dict['sample_sampler'] = m.group(1) continue m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt + if m: prompt_dict['controlnet_image'] = m.group(1) continue From fea810b437e0b4ea448e0ffa7d5933437bac6cae Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:44:57 +0900 Subject: [PATCH 05/12] Added --sample_at_first to generate sample images before training --- library/train_util.py | 19 +++++++++++++------ sdxl_train.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0f5033413..926e956ca 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2968,6 +2968,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" ) + parser.add_argument( + "--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する" + ) parser.add_argument( "--sample_every_n_epochs", type=int, @@ -4429,15 +4432,19 @@ def sample_images_common( """ StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した """ - if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: - return - if args.sample_every_n_epochs is not None: - # sample_every_n_steps は無視する - if epoch is None or epoch % args.sample_every_n_epochs != 0: + if steps == 0: + if not args.sample_at_first: return else: - if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + if args.sample_every_n_steps is None and args.sample_every_n_epochs is None: return + if args.sample_every_n_epochs is not None: + # sample_every_n_steps は無視する + if epoch is None or epoch % args.sample_every_n_epochs != 0: + return + else: + if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch + return print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}") if not os.path.isfile(args.sample_prompts): diff --git a/sdxl_train.py b/sdxl_train.py index 47bc6a420..a25da42d1 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -477,6 +477,19 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + sdxl_train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + [tokenizer1, tokenizer2], + [text_encoder1, text_encoder2], + unet, + ) + for m in training_models: m.train() From 5c150675bf1a4a0153b7fa404515ebe76f3e1698 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 21:46:47 +0900 Subject: [PATCH 06/12] Added --sample_at_first description --- docs/train_README-ja.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index c871f0769..d186bf243 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。 +- `--sample_at_first` + + 学習開始前にサンプル出力します。学習前との比較ができます。 + - `--sample_prompts` サンプル出力用プロンプトのファイルを指定します。 From 2c731418add79c303213ea884eb3d66bfe6b19d7 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sun, 29 Oct 2023 22:08:42 +0900 Subject: [PATCH 07/12] Added sample_images() for --sample_at_first --- fine_tune.py | 3 +++ train_controlnet.py | 14 ++++++++++++++ train_db.py | 2 ++ train_network.py | 4 +++- train_textual_inversion.py | 14 ++++++++++++++ 5 files changed, 36 insertions(+), 1 deletion(-) diff --git a/fine_tune.py b/fine_tune.py index a86a483a0..597678403 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -303,6 +303,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + for m in training_models: m.train() diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..d054d32eb 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -373,6 +373,20 @@ def remove_model(old_ckpt_name): # training loop for epoch in range(num_train_epochs): + # For --sample_at_first + train_util.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + if is_main_process: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 diff --git a/train_db.py b/train_db.py index fd8e466e5..443cc5bfc 100644 --- a/train_db.py +++ b/train_db.py @@ -279,6 +279,8 @@ def train(args): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + train_util.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() # train==True is required to enable gradient_checkpointing diff --git a/train_network.py b/train_network.py index d50916b74..bf6597236 100644 --- a/train_network.py +++ b/train_network.py @@ -749,7 +749,9 @@ def remove_model(old_ckpt_name): current_epoch.value = epoch + 1 metadata["ss_epoch"] = str(epoch + 1) - + + # For --sample_at_first + self.sample_images(accelerator, args, epoch, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..2a347afa2 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -534,6 +534,20 @@ def remove_model(old_ckpt_name): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 + # For --sample_at_first + self.sample_images( + accelerator, + args, + epoch, + global_step, + accelerator.device, + vae, + tokenizer_or_list, + text_encoder_or_list, + unet, + prompt_replacement, + ) + for text_encoder in text_encoders: text_encoder.train() From da5a144589b93a3bd463291ce4c47fec5f2f0f6e Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Sat, 18 Nov 2023 07:47:27 -0500 Subject: [PATCH 08/12] Add PagedAdamW --- library/train_util.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index cc9ac4555..7e94158c9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2657,7 +2657,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -3373,7 +3373,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -3477,6 +3477,20 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "PagedAdamW".lower(): + print(f"use PagedAdamW optimizer | {optimizer_kwargs}") + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです") + try: + optimizer_class = bnb.optim.PagedAdamW + except AttributeError: + raise AttributeError( + "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "PagedAdamW32bit".lower(): print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}") try: From c856ea42490df923773109180a86b7158227801c Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 19 Nov 2023 12:11:36 -0500 Subject: [PATCH 09/12] Add attention processor --- library/original_unet.py | 47 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index 240b85951..027210110 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -569,6 +569,9 @@ def __init__( self.use_memory_efficient_attention_mem_eff = False self.use_sdpa = False + # Attention processor + self.processor = None + def set_use_memory_efficient_attention(self, xformers, mem_eff): self.use_memory_efficient_attention_xformers = xformers self.use_memory_efficient_attention_mem_eff = mem_eff @@ -590,7 +593,28 @@ def reshape_batch_dim_to_heads(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def forward(self, hidden_states, context=None, mask=None): + def set_processor(self): + return self.processor + + def get_processor(self): + return self.processor + + def forward(self, hidden_states, context=None, mask=None, **kwargs): + if self.processor is not None: + ( + hidden_states, + encoder_hidden_states, + attention_mask, + ) = translate_attention_names_from_diffusers( + hidden_states=hidden_states, context=context, mask=mask, **kwargs + ) + return self.processor( + attn=self, + hidden_states=hidden_states, + encoder_hidden_states=context, + attention_mask=mask, + **kwargs + ) if self.use_memory_efficient_attention_xformers: return self.forward_memory_efficient_xformers(hidden_states, context, mask) if self.use_memory_efficient_attention_mem_eff: @@ -703,6 +727,21 @@ def forward_sdpa(self, x, context=None, mask=None): out = self.to_out[0](out) return out +def translate_attention_names_from_diffusers( + hidden_states: torch.FloatTensor, + context: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + # HF naming + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None +): + # translate from hugging face diffusers + context = context if context is not None else encoder_hidden_states + + # translate from hugging face diffusers + mask = mask if mask is not None else attention_mask + + return hidden_states, context, mask # feedforward class GEGLU(nn.Module): @@ -1331,7 +1370,7 @@ def __init__( self.out_channels = OUT_CHANNELS self.sample_size = sample_size - self.prepare_config() + self.prepare_config(sample_size=sample_size) # state_dictの書式が変わるのでmoduleの持ち方は変えられない @@ -1418,8 +1457,8 @@ def __init__( self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) # region diffusers compatibility - def prepare_config(self): - self.config = SimpleNamespace() + def prepare_config(self, *args, **kwargs): + self.config = SimpleNamespace(**kwargs) @property def dtype(self) -> torch.dtype: From 39bb319d4cac05d7da054ee726f86061e629574d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 29 Nov 2023 12:42:12 +0900 Subject: [PATCH 10/12] fix to work with cfg scale=1 --- sdxl_gen_img.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 78b90f8c3..ab5399842 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -504,7 +504,8 @@ def __call__( uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: @@ -567,9 +568,11 @@ def __call__( text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device) From ee46134fa7f9b471b4aca90e4aba13102ed6cd02 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Dec 2023 18:24:50 +0900 Subject: [PATCH 11/12] update readme --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index 0edaca25f..c4b91ea15 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,30 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Dec 3, 2023 / 2023/12/3 + +- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details. +- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - The default values are same as the previous version. +- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`. + - `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages. + - `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages. + - `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink. + - `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available. + +- `finetune\tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- V-predicition (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。 +- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - デフォルト値は前のバージョンと同じです。 +- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。 + - `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。 + - `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。 + - `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。 + - `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。 + ### Nov 5, 2023 / 2023/11/5 - `sdxl_train.py` now supports different learning rates for each Text Encoder. From f24a3b52828cf6747940b6047d076404f08309f4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Dec 2023 21:15:30 +0900 Subject: [PATCH 12/12] show seed in generating samples --- library/train_util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index a94562a33..da588980d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4692,6 +4692,8 @@ def sample_images_common( print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") print(f"sample_sampler: {sampler_name}") + if seed is not None: + print(f"seed: {seed}") with accelerator.autocast(): latents = pipeline( prompt=prompt,