diff --git a/README.md b/README.md index 0edaca25f..c4b91ea15 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,30 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Dec 3, 2023 / 2023/12/3 + +- `finetune/tag_images_by_wd14_tagger.py` now supports separators other than `,` with the `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- Min SNR Gamma with V-prediction (SD 2.1) is fixed. Thanks to feffy380! PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details. +- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - The default values are the same as the previous version. +- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`. + - `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages. + - `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages. + - `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means half of the original latent size for the Deep Shrink. + - `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available. 
+ +- `finetune/tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- V-prediction (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。 +- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - デフォルト値は前のバージョンと同じです。 +- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。 + - `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。 + - `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。 + - `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。 + - `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。 + ### Nov 5, 2023 / 2023/11/5 - `sdxl_train.py` now supports different learning rates for each Text Encoder. 
diff --git a/fine_tune.py b/fine_tune.py index 52e84c43f..b07876776 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 965edd7e2..fbf328e83 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -160,7 +160,9 @@ def main(args): tag_freq = {} - undesired_tags = set(args.undesired_tags.split(",")) + caption_separator = args.caption_separator + stripped_caption_separator = caption_separator.strip() + undesired_tags = set(args.undesired_tags.split(stripped_caption_separator)) def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) @@ -194,7 +196,7 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += ", " + tag_name + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) elif i >= len(general_tags) and p >= args.character_threshold: tag_name = character_tags[i - len(general_tags)] @@ -203,18 +205,18 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += ", " + tag_name + character_tag_text += caption_separator + tag_name combined_tags.append(tag_name) # 先頭のカンマを取る if len(general_tag_text) > 0: - general_tag_text = general_tag_text[2:] + general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: - character_tag_text = 
character_tag_text[2:] + character_tag_text = character_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension - tag_text = ", ".join(combined_tags) + tag_text = caption_separator.join(combined_tags) if args.append_tags: # Check if file exists @@ -224,13 +226,13 @@ def run_batch(path_imgs): existing_content = f.read().strip("\n") # Remove newlines # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()] # Check and remove repeating tags in tag_text new_tags = [tag for tag in combined_tags if tag not in existing_tags] # Create new tag_text - tag_text = ", ".join(existing_tags + new_tags) + tag_text = caption_separator.join(existing_tags + new_tags) with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") @@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--caption_separator", + type=str, + default=", ", + help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください", + ) return parser diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a596a0494..be43847a6 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -105,7 +105,7 @@ from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo -from library.original_unet import UNet2DConditionModel +from library.original_unet import 
UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -378,7 +378,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: InferUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, clip_model: CLIPModel, @@ -2196,6 +2196,7 @@ def main(args): ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) # VAEを読み込む if args.vae is not None: @@ -2352,13 +2353,20 @@ def __getattr__(self, item): vae = sli_vae del sli_vae vae.to(dtype).to(device) + vae.eval() text_encoder.to(dtype).to(device) unet.to(dtype).to(device) + + text_encoder.eval() + unet.eval() + if clip_model is not None: clip_model.to(dtype).to(device) + clip_model.eval() if vgg16_model is not None: vgg16_model.to(dtype).to(device) + vgg16_model.eval() # networkを組み込む if args.network_module: @@ -2501,6 +2509,10 @@ def __getattr__(self, item): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -3085,6 +3097,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: 
{prompt}") @@ -3156,10 +3175,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 if args.ds_depth_1 is not None else 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3509,6 +3569,30 @@ def setup_parser() -> argparse.ArgumentParser: # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) + # Deep Shrink + parser.add_argument( + 
"--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + return parser diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..ab90fb63b 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -51,6 +51,7 @@ class BaseSubsetParams: image_dir: Optional[str] = None num_repeats: int = 1 shuffle_caption: bool = False + caption_separator: str = "," keep_tokens: int = 0 color_aug: bool = False flip_aug: bool = False diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 28b625d30..e0a026dae 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + else: + 
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight return loss diff --git a/library/original_unet.py b/library/original_unet.py index 240b85951..938b0b64c 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -361,6 +361,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class SampleOutput: def __init__(self, sample): self.sample = sample @@ -1130,6 +1147,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1221,6 +1239,7 @@ def forward( # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1519,7 +1538,6 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 @@ -1604,3 +1622,255 @@ def handle_unusual_timesteps(self, sample, timesteps): timesteps = timesteps.expand(sample.shape[0]) return timesteps + + +class InferUNet2DConditionModel: + def __init__(self, original_unet: UNet2DConditionModel): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # override original model's up blocks' forward method + for up_block in self.delegate.up_blocks: + if up_block.__class__.__name__ == "UpBlock2D": + + def resnet_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = resnet_wrapper(self.up_block_forward, up_block) + + elif up_block.__class__.__name__ == "CrossAttnUpBlock2D": + + def cross_attn_up_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block) + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, 
timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in _self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def cross_attn_up_block_forward( + self, + _self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(_self.resnets, _self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: 
torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink. + """ + + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + _self = self.delegate + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**_self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. 
time + timesteps = timestep + timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = _self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=_self.dtype) + emb = _self.time_embedding(t_emb) + + # 2. pre-process + sample = _self.conv_in(sample) + + down_block_res_samples = (sample,) + for depth, downsample_block in enumerate(_self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. 
mid + sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(_self.up_blocks): + is_final_block = i == len(_self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = _self.conv_norm_out(sample) + sample = _self.conv_act(sample) + sample = _self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 26a0af319..babda8ec5 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -24,7 +24,7 @@ import math from types import SimpleNamespace -from typing import Optional +from typing import Any, Optional import torch import torch.utils.checkpoint from torch import nn @@ -266,6 +266,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. 
+def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class GroupNorm32(nn.GroupNorm): def forward(self, x): if self.weight.dtype != torch.float32: @@ -1077,6 +1094,7 @@ def call_module(module, h, emb, context): # h = x.type(self.dtype) h = x + for module in self.input_blocks: h = call_module(module, h, emb, context) hs.append(h) @@ -1093,6 +1111,121 @@ def call_module(module, h, emb, context): return h +class InferSdxlUNet2DConditionModel: + def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 
= ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + r""" + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + """ + _self = self.delegate + + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = _self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == _self.dtype + emb = emb + _self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + + for depth, module in enumerate(_self.input_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + # print("downsample", h.shape, self.ds_ratio) + org_dtype = h.dtype + if org_dtype == torch.bfloat16: + h = h.to(torch.float32) + h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(_self.middle_block, h, emb, context) + + for module in _self.output_blocks: + # Deep Shrink + if 
self.ds_depth_1 is not None: + if hs[-1].shape[-2:] != h.shape[-2:]: + # print("upsample", h.shape, hs[-1].shape) + h = resize_like(h, hs[-1]) + + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + # Deep Shrink: in case of depth 0 + if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: + # print("upsample", h.shape, x.shape) + h = resize_like(h, x) + + h = h.type(x.dtype) + h = call_module(_self.out, h, emb, context) + + return h + + if __name__ == "__main__": import time diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 31b2bd0a4..5c4e056d3 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -62,7 +62,7 @@ def cat_h(sliced): return x -def resblock_forward(_self, num_slices, input_tensor, temb): +def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): assert _self.upsample is None and _self.downsample is None assert _self.norm1.num_groups == _self.norm2.num_groups assert temb is None diff --git a/library/train_util.py b/library/train_util.py index cc9ac4555..9fb616ed6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -349,6 +349,7 @@ def __init__( image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, + caption_separator: str, keep_tokens: int, color_aug: bool, flip_aug: bool, @@ -365,6 +366,7 @@ def __init__( self.image_dir = image_dir self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption + self.caption_separator = caption_separator self.keep_tokens = keep_tokens self.color_aug = color_aug self.flip_aug = flip_aug @@ -391,6 +393,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, + caption_separator: str, keep_tokens, color_aug, flip_aug, @@ -410,6 +413,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -443,6 +447,7 @@ def __init__( metadata_file: str, num_repeats, shuffle_caption, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -462,6 
+467,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -492,6 +498,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -511,6 +518,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -646,7 +654,7 @@ def process_caption(self, subset: BaseSubset, caption): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(",")] + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: @@ -3105,7 +3113,10 @@ def add_dataset_arguments( # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument( - "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする" + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument( + "--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字" ) parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index dba7cd4e2..6357df55d 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -13,8 +13,8 @@ import lora -CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-1 +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 def save_to_file(file_name, model, state_dict, dtype): @@ -29,7 
+29,21 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(args): +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + v2=None, + sdxl=None, + conv_dim=None, + v_parameterization=None, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, +): def str_to_dtype(p): if p == "float": return torch.float @@ -39,44 +53,42 @@ def str_to_dtype(p): return torch.bfloat16 return None - assert args.v2 != args.sdxl or ( - not args.v2 and not args.sdxl - ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if args.v_parameterization is None: - args.v_parameterization = args.v2 + assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" + if v_parameterization is None: + v_parameterization = v2 - save_dtype = str_to_dtype(args.save_precision) + save_dtype = str_to_dtype(save_precision) # load models - if not args.sdxl: - print(f"loading original SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + if not sdxl: + print(f"loading original SD model : {model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + print(f"loading tuned SD model : {model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] - model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - print(f"loading original SDXL model : {args.model_org}") + 
print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {args.model_tuned}") + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" ) text_encoders_t = [text_encoder_t1, text_encoder_t2] model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha - if args.conv_dim is None: + if conv_dim is None: kwargs = {} else: - kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) + lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -91,9 +103,9 @@ def str_to_dtype(p): diff = module_t.weight - module_o.weight # Text Encoder might be same - if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: + if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. 
{torch.max(torch.abs(diff))} > {MIN_DIFF}") + print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") diff = diff.float() diffs[lora_name] = diff @@ -120,16 +132,16 @@ def str_to_dtype(p): lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): - # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) - rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim out_dim, in_dim = mat.size()[0:2] - if args.device: - mat = mat.to(args.device) + if device: + mat = mat.to(device) # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim @@ -149,7 +161,7 @@ def str_to_dtype(p): Vh = Vh[:rank, :] dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) + hi_val = torch.quantile(dist, clamp_quantile) low_val = -hi_val U = U.clamp(low_val, hi_val) @@ -178,34 +190,32 @@ def str_to_dtype(p): info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") - dir_name = os.path.dirname(args.save_to) + dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # minimum metadata net_kwargs = {} - if args.conv_dim is not None: - net_kwargs["conv_dim"] = args.conv_dim - net_kwargs["conv_alpha"] = args.conv_dim + if conv_dim is not None: + net_kwargs["conv_dim"] = str(conv_dim) + net_kwargs["conv_alpha"] = str(float(conv_dim)) metadata = { - "ss_v2": str(args.v2), + "ss_v2": str(v2), "ss_base_model_version": model_version, "ss_network_module": "networks.lora", - "ss_network_dim": str(args.dim), - "ss_network_alpha": 
str(args.dim), + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), "ss_network_args": json.dumps(net_kwargs), } - if not args.no_metadata: - title = os.path.splitext(os.path.basename(args.save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title - ) + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] + sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) metadata.update(sai_metadata) - lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + lora_network_save.save_weights(save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: @@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument( "--v_parameterization", - type=bool, + action="store_true", default=None, help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", ) @@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser: "--model_org", type=str, default=None, + required=True, help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", ) parser.add_argument( "--model_tuned", type=str, default=None, + required=True, help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: ckpt or 
safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") parser.add_argument( @@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser: help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", ) parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", + ) + parser.add_argument( + "--min_diff", + type=float, + default=0.01, + help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", + ) parser.add_argument( "--no_metadata", action="store_true", @@ -264,4 +293,4 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() - svd(args) + svd(**vars(args)) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c31ae0072..ab5399842 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -57,7 +57,7 @@ import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite @@ -290,7 +290,7 @@ def __init__( vae: AutoencoderKL, text_encoders: List[CLIPTextModel], tokenizers: List[CLIPTokenizer], - unet: SdxlUNet2DConditionModel, + unet: InferSdxlUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, ): @@ -328,7 +328,7 @@ def 
__init__( self.vae = vae self.text_encoders = text_encoders self.tokenizers = tokenizers - self.unet: SdxlUNet2DConditionModel = unet + self.unet: InferSdxlUNet2DConditionModel = unet self.scheduler = scheduler self.safety_checker = None @@ -504,7 +504,8 @@ def __call__( uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: @@ -567,9 +568,11 @@ def __call__( text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device) @@ -1371,6 +1374,7 @@ def main(args): (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1526,10 +1530,14 @@ def __getattr__(self, item): print("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) + vae.eval() text_encoder1.to(dtype).to(device) text_encoder2.to(dtype).to(device) unet.to(dtype).to(device) + text_encoder1.eval() + text_encoder2.eval() + unet.eval() # networkを組み込む if args.network_module: @@ -1696,6 +1704,10 @@ def __getattr__(self, item): if 
args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -2286,6 +2298,13 @@ def scale_and_round(x): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -2393,10 +2412,51 @@ def scale_and_round(x): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # 
-1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2734,6 +2794,31 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 54abf697c..44447d1f0 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, 
noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f00f10eaa..91cbacc6a 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -430,7 +430,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..e0118d1c5 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -449,7 +449,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 7fbbc18ac..966999dfb 100644 --- a/train_db.py +++ b/train_db.py @@ -342,7 +342,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/train_network.py b/train_network.py index d50916b74..1cbed2e7b 100644 --- a/train_network.py +++ 
b/train_network.py @@ -812,7 +812,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..45a437b91 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -578,7 +578,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8dd5c672f..f77ad2eb2 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -469,7 +469,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: