From a4161d04a2d3be3fc5d58f54fd2a96c5e29acafd Mon Sep 17 00:00:00 2001 From: Linaqruf Date: Mon, 2 Jan 2023 12:03:57 +0700 Subject: [PATCH] feat: integrating kohya-trained.ipynb to new repo format --- convert_diffusers20_original_sd/model_util.py | 1167 ------------- diffuser_fine_tuning/requirements.txt | 8 - .../fine_tune.py => fine_tune.py | 4 +- finetune/blip/blip.py | 240 +++ finetune/blip/med.py | 955 +++++++++++ finetune/blip/med_config.json | 22 + finetune/blip/vit.py | 305 ++++ .../clean_captions_and_tags.py | 0 .../hypernetwork_nai.py | 0 .../make_captions.py | 25 +- .../merge_captions_to_metadata.py | 0 .../merge_dd_tags_to_metadata.py | 0 .../prepare_buckets_latents.py | 8 +- .../tag_images_by_wd14_tagger.py | 0 ...n_img_diffusers.py => gen_img_diffusers.py | 73 +- gen_img_diffusers/model_util.py | 1182 -------------- kohya-dreambooth.ipynb | 660 ++++++-- kohya-trainer.ipynb | 401 +++-- library/__init__.py | 0 .../model_util.py | 12 +- networks/lora.py | 190 +++ networks/merge_lora.py | 159 ++ requirements.txt | 23 + script/detect_face_rotate_v3.py | 235 --- script/requirements.txt | 11 - script/tag_images_by_wd14_tagger.py | 107 -- setup.py | 3 + {script => tools}/code-snippet.ipynb | 0 .../convert_diffusers20_original_sd.py | 4 +- tools/detect_face_rotate.py | 239 +++ {script => tools}/merge_block_weighted.py | 0 {script => tools}/merge_vae.py | 0 .../train_db_fixed.py => train_db.py | 9 +- train_db_fixed/model_util.py | 1182 -------------- train_network.py | 1453 +++++++++++++++++ 35 files changed, 4390 insertions(+), 4287 deletions(-) delete mode 100644 convert_diffusers20_original_sd/model_util.py delete mode 100644 diffuser_fine_tuning/requirements.txt rename diffuser_fine_tuning/fine_tune.py => fine_tune.py (99%) create mode 100644 finetune/blip/blip.py create mode 100644 finetune/blip/med.py create mode 100644 finetune/blip/med_config.json create mode 100644 finetune/blip/vit.py rename {diffuser_fine_tuning => finetune}/clean_captions_and_tags.py (100%) rename {diffuser_fine_tuning => finetune}/hypernetwork_nai.py (100%) rename {diffuser_fine_tuning => finetune}/make_captions.py (83%) rename {diffuser_fine_tuning => finetune}/merge_captions_to_metadata.py (100%) rename {diffuser_fine_tuning => finetune}/merge_dd_tags_to_metadata.py (100%) rename {diffuser_fine_tuning => finetune}/prepare_buckets_latents.py (94%) rename {diffuser_fine_tuning => finetune}/tag_images_by_wd14_tagger.py (100%) rename gen_img_diffusers/gen_img_diffusers.py => gen_img_diffusers.py (98%) delete mode 100644 gen_img_diffusers/model_util.py create mode 100644 library/__init__.py rename {diffuser_fine_tuning => library}/model_util.py (99%) create mode 100644 networks/lora.py create mode 100644 networks/merge_lora.py create mode 100644 requirements.txt delete mode 100644 script/detect_face_rotate_v3.py delete mode 100644 script/requirements.txt delete mode 100644 script/tag_images_by_wd14_tagger.py create mode 100644 setup.py rename {script => tools}/code-snippet.ipynb (100%) rename {convert_diffusers20_original_sd => tools}/convert_diffusers20_original_sd.py (98%) create mode 100644 tools/detect_face_rotate.py rename {script => tools}/merge_block_weighted.py (100%) rename {script => tools}/merge_vae.py (100%) rename train_db_fixed/train_db_fixed.py => train_db.py (99%) delete mode 100644 train_db_fixed/model_util.py create mode 100644 train_network.py diff --git a/convert_diffusers20_original_sd/model_util.py b/convert_diffusers20_original_sd/model_util.py deleted file mode 100644 index 9610c90e..00000000 --- a/convert_diffusers20_original_sd/model_util.py +++ /dev/null @@ -1,1167 +0,0 @@ -# v1: split from train_db_fixed.py. -# v2: support safetensors - -import math -import os -import torch -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from safetensors.torch import load_file, save_file - -# DiffUsers版StableDiffusionのモデルパラメータ -NUM_TRAIN_TIMESTEPS = 1000 -BETA_START = 0.00085 -BETA_END = 0.0120 - -UNET_PARAMS_MODEL_CHANNELS = 320 -UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] -UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] -UNET_PARAMS_IMAGE_SIZE = 32 # unused -UNET_PARAMS_IN_CHANNELS = 4 -UNET_PARAMS_OUT_CHANNELS = 4 -UNET_PARAMS_NUM_RES_BLOCKS = 2 -UNET_PARAMS_CONTEXT_DIM = 768 -UNET_PARAMS_NUM_HEADS = 8 - -VAE_PARAMS_Z_CHANNELS = 4 -VAE_PARAMS_RESOLUTION = 256 -VAE_PARAMS_IN_CHANNELS = 3 -VAE_PARAMS_OUT_CH = 3 -VAE_PARAMS_CH = 128 -VAE_PARAMS_CH_MULT = [1, 2, 4, 4] -VAE_PARAMS_NUM_RES_BLOCKS = 2 - -# V2 -V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] -V2_UNET_PARAMS_CONTEXT_DIM = 1024 - - -# region StableDiffusion->Diffusersの変換コード -# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) - - -def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) - - return config - - -def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config - - -def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict - - -def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) - return new_sd - -# endregion - - -# region Diffusers->StableDiffusion の変換コード -# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - -def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - - -def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict - - -# ================# -# VAE Conversion # -# ================# - -def reshape_weight_for_sd(w): - # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) - - -def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict - - -# endregion - -# region 自作のモデル読み書きなど - -def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' - - -def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, "cpu") - else: - checkpoint = torch.load(ckpt_path, map_location="cpu") - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - checkpoint = None - - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) - - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - return checkpoint, state_dict - - -# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) - - return text_model, vae, unet - - -def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) - for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key] - - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) - - return new_sd - - -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), - tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = "first_stage_model." - - -def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - -# endregion - - -def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) - - resos = set() - - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) - - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) - - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) - - size += divisible - - resos = list(resos) - resos.sort() - - aspect_ratios = [w / h for w, h in resos] - return resos, aspect_ratios - - -if __name__ == '__main__': - resos, aspect_ratios = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - print(aspect_ratios) - - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) diff --git a/diffuser_fine_tuning/requirements.txt b/diffuser_fine_tuning/requirements.txt deleted file mode 100644 index 673de003..00000000 --- a/diffuser_fine_tuning/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -accelerate -transformers>=4.21.0 -ftfy -albumentations -opencv-python -einops -pytorch_lightning -safetensors diff --git a/diffuser_fine_tuning/fine_tune.py b/fine_tune.py similarity index 99% rename from diffuser_fine_tuning/fine_tune.py rename to fine_tune.py index 66e3c1f1..49d84dcc 100644 --- a/diffuser_fine_tuning/fine_tune.py +++ b/fine_tune.py @@ -50,7 +50,7 @@ from einops import rearrange from torch import einsum -import model_util +import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -540,7 +540,7 @@ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): # v4で更新:clip_sample=Falseに # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ + # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) diff --git a/finetune/blip/blip.py b/finetune/blip/blip.py new file mode 100644 index 00000000..7851fb08 --- /dev/null +++ b/finetune/blip/blip.py @@ -0,0 +1,240 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import warnings +warnings.filterwarnings("ignore") + +# from models.vit import VisionTransformer, interpolate_pos_embed +# from models.med import BertConfig, BertModel, BertLMHeadModel +from blip.vit import VisionTransformer, interpolate_pos_embed +from blip.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +import os +from urllib.parse import urlparse +from timm.models.hub import download_cached_file + +class BLIP_Base(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + + def forward(self, image, caption, mode): + + assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" + text = self.tokenizer(caption, return_tensors="pt").to(image.device) + + if mode=='image': + # return image features + image_embeds = self.visual_encoder(image) + return image_embeds + + elif mode=='text': + # return text features + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + return text_output.last_hidden_state + + elif mode=='multimodal': + # return multimodel features + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text.input_ids[:,0] = self.tokenizer.enc_token_id + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + return output.last_hidden_state + + + +class BLIP_Decoder(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + prompt = 'a picture of ', + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel(config=med_config) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 + + + def forward(self, image, caption): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) + + text.input_ids[:,0] = self.tokenizer.bos_token_id + + decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) + decoder_targets[:,:self.prompt_length] = -100 + + decoder_output = self.text_decoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + loss_lm = decoder_output.loss + + return loss_lm + + def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): + image_embeds = self.visual_encoder(image) + + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} + + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) + input_ids[:,0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + #nucleus sampling + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + #beam search + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = self.tokenizer.decode(output, skip_special_tokens=True) + captions.append(caption[len(self.prompt):]) + return captions + + +def blip_decoder(pretrained='',**kwargs): + model = BLIP_Decoder(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def blip_feature_extractor(pretrained='',**kwargs): + model = BLIP_Base(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer.add_special_tokens({'bos_token':'[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit=='base': + vision_width = 768 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, + num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate + ) + elif vit=='large': + vision_width = 1024 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, + num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate + ) + return visual_encoder, vision_width + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], + model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape!=model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + diff --git a/finetune/blip/med.py b/finetune/blip/med.py new file mode 100644 index 00000000..7b00a354 --- /dev/null +++ b/finetune/blip/med.py @@ -0,0 +1,955 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/finetune/blip/med_config.json b/finetune/blip/med_config.json new file mode 100644 index 00000000..dc12b99c --- /dev/null +++ b/finetune/blip/med_config.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true + } + \ No newline at end of file diff --git a/finetune/blip/vit.py b/finetune/blip/vit.py new file mode 100644 index 00000000..cec3d8e0 --- /dev/null +++ b/finetune/blip/vit.py @@ -0,0 +1,305 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/diffuser_fine_tuning/clean_captions_and_tags.py b/finetune/clean_captions_and_tags.py similarity index 100% rename from diffuser_fine_tuning/clean_captions_and_tags.py rename to finetune/clean_captions_and_tags.py diff --git a/diffuser_fine_tuning/hypernetwork_nai.py b/finetune/hypernetwork_nai.py similarity index 100% rename from diffuser_fine_tuning/hypernetwork_nai.py rename to finetune/hypernetwork_nai.py diff --git a/diffuser_fine_tuning/make_captions.py b/finetune/make_captions.py similarity index 83% rename from diffuser_fine_tuning/make_captions.py rename to finetune/make_captions.py index dd70b1b1..b02420bd 100644 --- a/diffuser_fine_tuning/make_captions.py +++ b/finetune/make_captions.py @@ -1,10 +1,8 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - import argparse import glob import os import json +import random from PIL import Image from tqdm import tqdm @@ -12,20 +10,34 @@ import torch from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from models.blip import blip_decoder +from blip.blip import blip_decoder # from Salesforce_BLIP.models.blip import blip_decoder DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def main(args): + # fix the seed for reproducibility + seed = args.seed # + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + if not os.path.exists("blip"): + args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path + + cwd = os.getcwd() + print('Current Working Directory is: ', cwd) + os.chdir('finetune') + + print(f"load images from {args.train_data_dir}") image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) print(f"found {len(image_paths)} images.") print(f"loading BLIP caption: {args.caption_weights}") image_size = 384 - model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large') + model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json") model.eval() model = model.to(DEVICE) print("BLIP loaded") @@ -75,7 +87,7 @@ def run_batch(path_imgs): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("caption_weights", type=str, + parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)") parser.add_argument("--caption_extention", type=str, default=None, help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") @@ -87,6 +99,7 @@ def run_batch(path_imgs): parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p") parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長") parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") + parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument("--debug", action="store_true", help="debug mode") args = parser.parse_args() diff --git a/diffuser_fine_tuning/merge_captions_to_metadata.py b/finetune/merge_captions_to_metadata.py similarity index 100% rename from diffuser_fine_tuning/merge_captions_to_metadata.py rename to finetune/merge_captions_to_metadata.py diff --git a/diffuser_fine_tuning/merge_dd_tags_to_metadata.py b/finetune/merge_dd_tags_to_metadata.py similarity index 100% rename from diffuser_fine_tuning/merge_dd_tags_to_metadata.py rename to finetune/merge_dd_tags_to_metadata.py diff --git a/diffuser_fine_tuning/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py similarity index 94% rename from diffuser_fine_tuning/prepare_buckets_latents.py rename to finetune/prepare_buckets_latents.py index f4c6a371..00f847a1 100644 --- a/diffuser_fine_tuning/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -14,7 +14,7 @@ import torch from torchvision import transforms -import model_util +import library.model_util as model_util DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -130,14 +130,16 @@ def main(args): latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype) for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name), latent) # flip if args.flip_aug: latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない for (image_key, reso, _), latent in zip(bucket, latents): - np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent) + npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key + np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent) bucket.clear() diff --git a/diffuser_fine_tuning/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py similarity index 100% rename from diffuser_fine_tuning/tag_images_by_wd14_tagger.py rename to finetune/tag_images_by_wd14_tagger.py diff --git a/gen_img_diffusers/gen_img_diffusers.py b/gen_img_diffusers.py similarity index 98% rename from gen_img_diffusers/gen_img_diffusers.py rename to gen_img_diffusers.py index 7532faa5..75f14afa 100644 --- a/gen_img_diffusers/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1,14 +1,15 @@ # txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc... # v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix -# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue hypernetwork_mul is not working +# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working # v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode # v5: fix clip_sample=True for scheduler, add VGG guidance # v6: refactor to use model util, load VAE without vae folder, support safe tensors # v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size, # Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction) # v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes, -# sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1 +# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1 +# v10: fix app crashes when different image size in prompts # Copyright 2022 kohya_ss @kohya_ss # @@ -112,7 +113,7 @@ from PIL import Image from PIL.PngImagePlugin import PngInfo -import model_util +import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -332,7 +333,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use Hypernetwork and FlashAttention") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention") flash_func = FlashAttentionFunction def forward_flash_attn(self, x, context=None, mask=None): @@ -372,7 +373,7 @@ def forward_flash_attn(self, x, context=None, mask=None): def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use Hypernetwork and xformers") + print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers") try: import xformers.ops except ImportError: @@ -1866,25 +1867,6 @@ def main(args): if not args.diffusers_xformers: replace_unet_modules(unet, not args.xformers, args.xformers) - # hypernetworkを組み込む - if args.hypernetwork_module is not None: - assert not args.diffusers_xformers, "cannot use hypernetwork with diffusers_xformers / diffusers_xformers指定時はHypernetworkは利用できません" - - print("import hypernetwork module:", args.hypernetwork_module) - hyp_module = importlib.import_module(args.hypernetwork_module) - - hypernetwork = hyp_module.Hypernetwork(args.hypernetwork_mul) - - print("load hypernetwork weights from:", args.hypernetwork_weights) - hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') - success = hypernetwork.load_from_state_dict(hyp_sd) - assert success, "hypernetwork weights loading failed." - - if args.opt_channels_last: - hypernetwork.to(memory_format=torch.channels_last) - else: - hypernetwork = None - # tokenizerを読み込む print("loading tokenizer") if use_stable_diffusion_format: @@ -1999,10 +1981,27 @@ def __getattr__(self, item): if vgg16_model is not None: vgg16_model.to(dtype).to(device) - if hypernetwork is not None: - hypernetwork.to(dtype).to(device) - print("apply hypernetwork") - hypernetwork.apply_to_diffusers(vae, text_encoder, unet) + # networkを組み込む + if args.network_module is not None: + # assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません" + + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs) + if network is None: + return + + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + else: + network = None if args.opt_channels_last: print(f"set optimizing: channels last") @@ -2011,8 +2010,8 @@ def __getattr__(self, item): unet.to(memory_format=torch.channels_last) if clip_model is not None: clip_model.to(memory_format=torch.channels_last) - if hypernetwork is not None: - hypernetwork.to(memory_format=torch.channels_last) + if network is not None: + network.to(memory_format=torch.channels_last) if vgg16_model is not None: vgg16_model.to(memory_format=torch.channels_last) @@ -2424,7 +2423,7 @@ def process_batch(batch, highres_fix, highres_1st=False): b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), (width, height, steps, scale, strength)) if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? - process_batch(batch_data) + process_batch(batch_data, highres_fix) batch_data.clear() batch_data.append(b1) @@ -2487,12 +2486,14 @@ def process_batch(batch, highres_fix, highres_1st=False): parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannles lastを指定し最適化する') - parser.add_argument("--hypernetwork_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') - parser.add_argument("--hypernetwork_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') - parser.add_argument("--hypernetwork_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + help='set channels last option to model / モデルにchannels lastを指定し最適化する') + parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') + parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') + parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + parser.add_argument("--network_dim", type=int, default=None, + help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') diff --git a/gen_img_diffusers/model_util.py b/gen_img_diffusers/model_util.py deleted file mode 100644 index f3453025..00000000 --- a/gen_img_diffusers/model_util.py +++ /dev/null @@ -1,1182 +0,0 @@ -# v1: split from train_db_fixed.py. -# v2: support safetensors - -import math -import os -import torch -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from safetensors.torch import load_file, save_file - -# DiffUsers版StableDiffusionのモデルパラメータ -NUM_TRAIN_TIMESTEPS = 1000 -BETA_START = 0.00085 -BETA_END = 0.0120 - -UNET_PARAMS_MODEL_CHANNELS = 320 -UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] -UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] -UNET_PARAMS_IMAGE_SIZE = 32 # unused -UNET_PARAMS_IN_CHANNELS = 4 -UNET_PARAMS_OUT_CHANNELS = 4 -UNET_PARAMS_NUM_RES_BLOCKS = 2 -UNET_PARAMS_CONTEXT_DIM = 768 -UNET_PARAMS_NUM_HEADS = 8 - -VAE_PARAMS_Z_CHANNELS = 4 -VAE_PARAMS_RESOLUTION = 256 -VAE_PARAMS_IN_CHANNELS = 3 -VAE_PARAMS_OUT_CH = 3 -VAE_PARAMS_CH = 128 -VAE_PARAMS_CH_MULT = [1, 2, 4, 4] -VAE_PARAMS_NUM_RES_BLOCKS = 2 - -# V2 -V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] -V2_UNET_PARAMS_CONTEXT_DIM = 1024 - -# Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" -DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" - - -# region StableDiffusion->Diffusersの変換コード -# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) - - -def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) - - return config - - -def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config - - -def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict - - -def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) - return new_sd - -# endregion - - -# region Diffusers->StableDiffusion の変換コード -# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - -def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - - -def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict - - -# ================# -# VAE Conversion # -# ================# - -def reshape_weight_for_sd(w): - # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) - - -def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict - - -# endregion - -# region 自作のモデル読み書きなど - -def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' - - -def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, "cpu") - else: - checkpoint = torch.load(ckpt_path, map_location="cpu") - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - checkpoint = None - - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) - - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - return checkpoint, state_dict - - -# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) - - return text_model, vae, unet - - -def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) - for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる - - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) - - return new_sd - - -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = "first_stage_model." - - -def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - -# endregion - - -def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) - - resos = set() - - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) - - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) - - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) - - size += divisible - - resos = list(resos) - resos.sort() - - aspect_ratios = [w / h for w, h in resos] - return resos, aspect_ratios - - -if __name__ == '__main__': - resos, aspect_ratios = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - print(aspect_ratios) - - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) diff --git a/kohya-dreambooth.ipynb b/kohya-dreambooth.ipynb index 7186749a..c54fbc6e 100644 --- a/kohya-dreambooth.ipynb +++ b/kohya-dreambooth.ipynb @@ -40,7 +40,6 @@ "cell_type": "markdown", "source": [ "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nee3ed1649fb6)
\n", - "Adapted again from [bmaltais's Kohya Archive](https://github.com/bmaltais/kohya_ss)
\n", "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
\n", "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-dreambooth.ipynb)\n", "\n", @@ -130,47 +129,190 @@ "execution_count": null, "outputs": [] }, + { + "cell_type": "code", + "source": [ + "#@title Install Pre-trained Model \n", + "%cd /content/kohya-trainer\n", + "import os\n", + "\n", + "# Check if directory exists\n", + "if not os.path.exists('checkpoint'):\n", + " # Create directory if it doesn't exist\n", + " os.makedirs('checkpoint')\n", + "\n", + "#@title Install Pre-trained Model \n", + "\n", + "installModels=[]\n", + "installVae= []\n", + "installVaeArgs = []\n", + "\n", + "#@markdown ### Available Model\n", + "#@markdown Select one of available pretrained model to download:\n", + "modelUrl = [\"\", \\\n", + " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \\\n", + " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/modelsfw-pruned.ckpt\", \\\n", + " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp16.ckpt\", \\\n", + " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp32.ckpt\", \\\n", + " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned.ckpt\", \\\n", + " \"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt\", \\\n", + " \"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt\", \\\n", + " \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float32.ckpt\"]\n", + "modelList = [\"\", \\\n", + " \"Animefull-final-pruned\", \\\n", + " \"Animesfw-final-pruned\", \\\n", + " \"Anything-V3.0-pruned-fp16\", \\\n", + " \"Anything-V3.0-pruned-fp32\", \\\n", + " \"Anything-V3.0-pruned\", \\\n", + " \"Stable-Diffusion-v1-4\", \\\n", + " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n", + " \"Waifu-Diffusion-v1-3-fp32\"]\n", + "modelName = \"Anything-V3.0-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n", + "\n", + "#@markdown ### Custom model\n", + "#@markdown The model URL should be a direct download link.\n", + "customName = \"\" #@param {'type': 'string'}\n", + "customUrl = \"\"#@param {'type': 'string'}\n", + "\n", + "\n", + "# Check if user has specified a custom model\n", + "if customName != \"\" and customUrl != \"\":\n", + " # Add custom model to list of models to install\n", + " installModels.append((customName, customUrl))\n", + "\n", + "# Check if user has selected a model\n", + "if modelName != \"\":\n", + " # Map selected model to URL\n", + " installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n", + "\n", + "#@markdown Select one of the VAEs to download, select `none` for not download VAE:\n", + "vaeUrl = [\"\", \\\n", + " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animevae/animevae.pt\", \\\n", + " \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\"]\n", + "vaeList = [\"none\", \\\n", + " \"anime.vae.pt\", \\\n", + " \"waifudiffusion.vae.pt\"]\n", + "vaeName = \"anime.vae.pt\" #@param [\"none\",\"anime.vae.pt\",\"waifudiffusion.vae.pt\"]\n", + "vae_args = [\"none\", \\\n", + " \"--vae-path /content/stable-diffusion-webui/models/Stable-diffusion/anime.vae.pt\", \\\n", + " \"--vae-path /content/stable-diffusion-webui/models/Stable-diffusion/waifudiffusion.vae.pt\"]\n", + "\n", + "installVae.append((vaeName, vaeUrl[vaeList.index(vaeName)]))\n", + "installVaeArgs.append((vae_args[vaeList.index(vaeName)]))\n", + "\n", + "def install_aria():\n", + " # Install aria2 if it is not already installed\n", + " if not os.path.exists('/usr/bin/aria2c'):\n", + " !apt install -y -qq aria2\n", + "\n", + "def install(checkpoint_name, url):\n", + " if url.endswith(\".ckpt\"):\n", + " dst = \"/content/kohya-trainer/checkpoint/\" + str(checkpoint_name) + \".ckpt\"\n", + " elif url.endswith(\".safetensors\"):\n", + " dst = \"/content/kohya-trainer/checkpoint/\" + str(checkpoint_name) + \".safetensors\"\n", + " elif url.endswith(\".pt\"):\n", + " dst = \"/content/kohya-trainer/checkpoint/\" + str(checkpoint_name)\n", + " else:\n", + " dst = \"/content/kohya-trainer/checkpoint/\" + str(checkpoint_name) + \".ckpt\"\n", + "\n", + " if url.startswith(\"https://drive.google.com\"):\n", + " # Use gdown to download file from Google Drive\n", + " !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n", + " elif url.startswith(\"magnet:?\"):\n", + " install_aria()\n", + " # Use aria2c to download file from magnet link\n", + " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o {dst} \"{url}\"\n", + " else:\n", + " user_token = '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'\n", + " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n", + " # Use wget to download file from URL\n", + " !wget -c --header={user_header} \"{url}\" -O {dst}\n", + "\n", + "def install_checkpoint():\n", + " # Iterate through list of models to install\n", + " for model in installModels:\n", + " # Call install function for each model\n", + " install(model[0], model[1])\n", + "\n", + " if vaeName != \"none\":\n", + " for vae in installVae:\n", + " install(vae[0], vae[1])\n", + " else:\n", + " pass\n", + "\n", + "# Call install_checkpoint function to download all models in the list\n", + "install_checkpoint()\n" + ], + "metadata": { + "id": "SoucgZQ6jgPQ", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", + "metadata": { + "id": "M0fzmhtywk_u" + }, + "source": [ + "# Prepare Cloud Storage (Huggingface/GDrive)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "cwIJdhEcwk_u" + }, + "outputs": [], "source": [ - "## Folders configuration\n", - "\n", - "Refer to the note to understand how to create the folder structure. In short it should look like:\n", - "\n", - "```\n", - "\n", - "|- \n", - " |- _\n", - "|- \n", - " |- _ \n", - "```\n", - "\n", - "Example for `asd dog` where `asd` is the token word and `dog` is the class. In this example the regularization `dog` class images contained in the folder will be repeated only 1 time and the `asd dog` images will be repeated 20 times:\n", - "\n", - "```\n", - "my_asd_dog_dreambooth\n", - "|- reg_dog\n", - " |- 1_dog\n", - " `- reg_image_1.png\n", - " `- reg_image_2.png\n", - " ...\n", - " `- reg_image_256.png\n", - "|- train_dog\n", - " |- 20_asd dog\n", - " `- dog1.png\n", - " ...\n", - " `- dog8.png\n", - "```" + "#@title Login to Huggingface hub\n", + "\n", + "#@markdown ## Instructions:\n", + "#@markdown 1. Of course, you need a Huggingface account first.\n", + "#@markdown 2. To create a huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n", + "\n", + "%cd /content/kohya-trainer\n", + "\n", + "from huggingface_hub import login\n", + "login()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "jVgHUUK_wk_v" + }, + "outputs": [], + "source": [ + "#@title Mount Google Drive\n", + "\n", + "from google.colab import drive\n", + "\n", + "mount_drive = True #@param {'type':'boolean'}\n", + "\n", + "if mount_drive:\n", + " drive.mount('/content/drive')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "#Preparing Datasets" ], "metadata": { - "id": "En9UUwGNMRMM" + "id": "Pz9A2bu1Cq73" } }, { "cell_type": "code", "source": [ - "#@title Create train and reg folder based on description above\n", - "\n", + "#@title Create train and reg folder \n", "# Import the os and shutil modules\n", "import os\n", "import shutil\n", @@ -191,7 +333,7 @@ "\n", "#@markdown ### Define the reg_folder variable\n", "reg_count = 1 #@param {type: \"integer\"}\n", - "reg_class =\"kasakai_hikaru\" #@param {type: \"string\"}\n", + "reg_class =\"1girl\" #@param {type: \"string\"}\n", "reg_folder = str(reg_count) + \"_\" + reg_class\n", "\n", "# Define the reg_directory variable\n", @@ -217,10 +359,10 @@ " os.mkdir(reg_folder_directory)\n", "\n", "#@markdown ### Define the train_folder variable\n", - "train_count = 3300 #@param {type: \"integer\"}\n", - "train_token = \"sls\" #@param {type: \"string\"}\n", - "train_class = \"kasakai_hikaru\" #@param {type: \"string\"}\n", - "train_folder = str(train_count) + \"_\" + train_token + \"_\" + train_class\n", + "train_count = 20 #@param {type: \"integer\"}\n", + "train_token = \"makora\" #@param {type: \"string\"}\n", + "train_class = \"1girl\" #@param {type: \"string\"}\n", + "train_folder = str(train_count) + \"_\" + train_token + \" \" + train_class\n", "\n", "# Define the train_directory variable\n", "train_directory = f\"{dreambooth_directory}/train_{train_class}\"\n", @@ -253,46 +395,6 @@ "execution_count": null, "outputs": [] }, - { - "cell_type": "markdown", - "source": [ - "#Preparing Datasets" - ], - "metadata": { - "id": "Pz9A2bu1Cq73" - } - }, - { - "cell_type": "code", - "source": [ - "\n", - "#@title Prepare Regularization Images\n", - "#@markdown Download regularization images provided by community\n", - "category = \"waifu-regularization-3.3k\" #@param [\"\", \"waifu-regularization-3.3k\", \"husbando-regularization-3.5k\"]\n", - "#@markdown Or you can use the file manager on the left panel to upload (drag and drop) to `reg_images` folder (it uploads faster)\n", - "def reg_images(url, name):\n", - " user_token = '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'\n", - " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n", - " !wget -c --header={user_header} \"{url}\" -O /content/dreambooth/reg_{reg_class}/{reg_folder}/{name}.zip\n", - "\n", - "if category != \"\":\n", - " if category == \"waifu-regularization-3.3k\":\n", - " reg_images(\"https://huggingface.co/datasets/waifu-research-department/regularization/resolve/main/waifu-regularization-3.3k.zip\", \"waifu-regularization-3.3k\")\n", - " !unzip /content/dreambooth/reg_{reg_class}/{reg_folder}/waifu-regularization-3.3k.zip -d /content/dreambooth/reg_{reg_class}/{reg_folder}\n", - " !rm /content/dreambooth/reg_{reg_class}/{reg_folder}/waifu-regularization-3.3k.zip\n", - " else:\n", - " reg_images(\"https://huggingface.co/datasets/waifu-research-department/regularization/resolve/main/husbando-regularization-3.5k.zip\", \"husbando-regularization-3.5k\")\n", - " !unzip /content/dreambooth/reg_{reg_class}/{reg_folder}/husbando-regularization-3.5k.zip -d /content/dreambooth/reg_{reg_class}/{reg_folder}\n", - " !rm /content/dreambooth/reg_{reg_class}/{reg_folder}/husbando-regularization-3.5k.zip\n", - " \n" - ], - "metadata": { - "cellView": "form", - "id": "R9JZSOuSzXe_" - }, - "execution_count": null, - "outputs": [] - }, { "cell_type": "code", "source": [ @@ -330,87 +432,227 @@ { "cell_type": "code", "source": [ - "#@title Install Pre-trained Model \n", + "#@title Download compressed (.zip) dataset (Optional)\n", + "\n", + "\n", + "#@markdown ### Define Download Parameter\n", + "datasets_url = \"https://huggingface.co/datasets/Linaqruf/dreambooth-dataset/resolve/main/makora-dreambooth.zip\" #@param {'type': 'string'}\n", + "dataset_dst = '/content/dreambooth.zip' #@param{'type':'string'}\n", + "#@markdown ### Define Auto-Unzip Parameter\n", + "extract_to = '/content/makora-dreambooth' #@param{'type':'string'}\n", + "unzip_module = \"use_7zip\" #@param [\"use_unzip\",\"use_7zip\",\"use_Zipfile\"]\n", + "\n", + "def download_and_unzip_dataset(url, zip_file, extract_to, unzip_module):\n", + " try:\n", + " # Download dataset\n", + " if url.startswith(\"https://drive.google.com\"):\n", + " # Use gdown to download file from Google Drive\n", + " !gdown -o \"{zip_file}\" --fuzzy \"{url}\"\n", + " elif url.startswith(\"magnet:?\"):\n", + " install_aria()\n", + " # Use aria2c to download file from magnet link\n", + " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o \"{zip_file}\" \"{url}\"\n", + " else:\n", + " user_token = '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'\n", + " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n", + " # Use wget to download file from URL\n", + " !wget -c -O \"{zip_file}\" --header={user_header} \"{url}\"\n", + "\n", + " # Unzip dataset\n", + " if unzip_module == \"use_7zip\":\n", + " !7z x $zip_file -o$extract_to\n", + " elif unzip_module == \"use_unzip\":\n", + " !unzip $zip_file -d $extract_to\n", + " elif unzip_module == \"use_Zipfile\":\n", + " import zipfile\n", + " with zipfile.ZipFile(zip_file, 'r') as zip_ref:\n", + " zip_ref.extractall(extract_to)\n", + " except Exception as e:\n", + " print(\"An error occurred while downloading or unzipping the file:\", e)\n", + "\n", + "# Call download_and_unzip_dataset function\n", + "download_and_unzip_dataset(datasets_url, dataset_dst, extract_to, unzip_module)" + ], + "metadata": { + "cellView": "form", + "id": "Po6HdriyM7oC" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Prepare Regularization Image\n", + "prompt = \"1girl, solo\" #@param {type: \"string\"}\n", + "negative = \"lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\" #@param {type: \"string\"}\n", + "model = \"/content/kohya-trainer/checkpoint/Anything-V3.0-pruned.ckpt\" #@param {type: \"string\"}\n", + "vae = \"/content/kohya-trainer/checkpoint/anime.vae.pt\" #@param {type: \"string\"}\n", + "reg_path = \"/content/makora-dreambooth/dreambooth/reg_1girl\" #@param {type: \"string\"}\n", + "scale = 12 #@param {type: \"slider\", min: 1, max: 40}\n", + "sampler = \"ddim\" #@param [\"ddim\", \"pndm\", \"lms\", \"euler\", \"euler_a\", \"heun\", \"dpm_2\", \"dpm_2_a\", \"dpmsolver\",\"dpmsolver++\", \"dpmsingle\", \"k_lms\", \"k_euler\", \"k_euler_a\", \"k_dpm_2\", \"k_dpm_2_a\"]\n", + "steps = 20 #@param {type: \"slider\", min: 1, max: 100}\n", + "precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n", + "width = 768 #@param {type: \"integer\"}\n", + "height = 768 #@param {type: \"integer\"}\n", + "batch_size = 4 #@param {type: \"integer\"}\n", + "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 40}\n", + "\n", + "train_num_images = sum(os.path.isfile(os.path.join(train_folder_directory, name)) for name in os.listdir(train_folder_directory))\n", + "print(\"You have \" + str(train_num_images) + \" training data.\")\n", + "\n", + "if reg_count == 0:\n", + " reg_num_images = 0\n", + "elif train_num_images > 0:\n", + " reg_num_images = sum(os.path.isfile(os.path.join(reg_folder_directory, name)) for name in os.listdir(reg_folder_directory))\n", + " print(\"You have \" + str(reg_num_images) + \" regularization images.\")\n", + " reg_num_images = (train_count * train_num_images) // reg_count - reg_num_images\n", + " print(\"You need \" + str(reg_num_images) + \" regularization images.\")\n", + " print(\"This process will generate \" + str(reg_num_images) + \" images left and place them in your regularization image path.\")\n", + "\n", + "!python /content/kohya-trainer/gen_img_diffusers/gen_img_diffusers.py \\\n", + " --ckpt {model} \\\n", + " --outdir {reg_path} \\\n", + " --xformers \\\n", + " --vae {vae} \\\n", + " --{precision} \\\n", + " --W {width} \\\n", + " --H {height} \\\n", + " --clip_skip {clip_skip} \\\n", + " --scale {scale} \\\n", + " --sampler {sampler} \\\n", + " --steps {steps} \\\n", + " --max_embeddings_multiples 3 \\\n", + " --batch_size {batch_size} \\\n", + " --images_per_prompt {reg_num_images} \\\n", + " --prompt \"{prompt} --n {negative}\"\n", + "\n" + ], + "metadata": { + "cellView": "form", + "id": "HHjTV_1HKRbR" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Preprocessing Training Data" + ], + "metadata": { + "id": "d-9eJhB_QnBL" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Dataset Labeling\n", "%cd /content/kohya-trainer\n", + "\n", + "import shutil\n", "import os\n", "\n", - "# Check if directory exists\n", - "if not os.path.exists('checkpoint'):\n", - " # Create directory if it doesn't exist\n", - " os.makedirs('checkpoint')\n", + "use_blip_captioning = False #@param {type :'boolean'}\n", + "use_wd_1_4_tagger = True #@param {type :'boolean'}\n", + "\n", + "global_batch_size = 8 #@param {type:'integer'}\n", + "if use_blip_captioning:\n", + " def clone_and_prepare_spaces():\n", + " \"\"\"\n", + " Clones the Spaces repository, downloads the BLIP model weights, and moves the make_captions.py script to the BLIP directory.\n", + " \"\"\"\n", + " # Constants\n", + " BLIP_WEIGHT_SOURCE_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'\n", + " BLIP_WEIGHT_DESTINATION_PATH = '/content/kohya-trainer/BLIP/model_large_caption.pth'\n", + " MAKE_CAPTION_SOURCE_PATH = '/content/kohya-trainer/diffuser_fine_tuning/make_captions.py'\n", + " MAKE_CAPTION_DESTINATION_PATH = '/content/kohya-trainer/BLIP/make_captions.py'\n", + "\n", + " # Install Git LFS\n", + " !git lfs install\n", "\n", - "#@title Install Pre-trained Model \n", + " # Clone the Spaces repository\n", + " !git clone https://huggingface.co/spaces/Salesforce/BLIP\n", "\n", - "installModels=[]\n", + " # Download the BLIP model weights\n", + " !wget -c {BLIP_WEIGHT_SOURCE_URL} -O {BLIP_WEIGHT_DESTINATION_PATH}\n", "\n", - "#@markdown ### Available Model\n", - "#@markdown Select one of available pretrained model to download:\n", - "modelUrl = [\"\", \\\n", - " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \\\n", - " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/modelsfw-pruned.ckpt\", \\\n", - " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp16.ckpt\", \\\n", - " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp32.ckpt\", \\\n", - " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned.ckpt\", \\\n", - " \"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt\", \\\n", - " \"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt\", \\\n", - " \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float32.ckpt\"]\n", - "modelList = [\"\", \\\n", - " \"Animefull-final-pruned\", \\\n", - " \"Animesfw-final-pruned\", \\\n", - " \"Anything-V3.0-pruned-fp16\", \\\n", - " \"Anything-V3.0-pruned-fp32\", \\\n", - " \"Anything-V3.0-pruned\", \\\n", - " \"Stable-Diffusion-v1-4\", \\\n", - " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n", - " \"Waifu-Diffusion-v1-3-fp32\"]\n", - "modelName = \"Anything-V3.0-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n", + " # Move the make_captions.py script to the BLIP directory\n", + " if os.path.exists(MAKE_CAPTION_SOURCE_PATH):\n", + " shutil.move(MAKE_CAPTION_SOURCE_PATH, MAKE_CAPTION_DESTINATION_PATH)\n", + " else:\n", + " pass\n", "\n", - "#@markdown ### Custom model\n", - "#@markdown The model URL should be a direct download link.\n", - "customName = \"\" #@param {'type': 'string'}\n", - "customUrl = \"\"#@param {'type': 'string'}\n", + " # Clone and prepare Spaces\n", + " clone_and_prepare_spaces()\n", "\n", - "# Check if user has specified a custom model\n", - "if customName != \"\" and customUrl != \"\":\n", - " # Add custom model to list of models to install\n", - " installModels.append((customName, customUrl))\n", + " %cd /content/kohya-trainer/BLIP\n", "\n", - "# Check if user has selected a model\n", - "if modelName != \"\":\n", - " # Map selected model to URL\n", - " installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n", + " caption_weights = \"model_large_caption.pth\"\n", "\n", - "def install_aria():\n", - " # Install aria2 if it is not already installed\n", - " if not os.path.exists('/usr/bin/aria2c'):\n", - " !apt install -y -qq aria2\n", + " !python make_captions.py \\\n", + " \"{train_folder_directory}\" \\\n", + " {caption_weights} \\\n", + " --batch_size {global_batch_size} \\\n", + " --caption_extension .caption\n", + "else:\n", + " pass\n", "\n", - "def install(checkpoint_name, url):\n", - " if url.startswith(\"https://drive.google.com\"):\n", - " # Use gdown to download file from Google Drive\n", - " !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n", - " elif url.startswith(\"magnet:?\"):\n", - " install_aria()\n", - " # Use aria2c to download file from magnet link\n", - " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n", - " else:\n", - " user_token = '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'\n", - " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n", - " # Use wget to download file from URL\n", - " !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n", + "if use_wd_1_4_tagger:\n", + " # Change the working directory to the weight directory\n", + " %cd /content/kohya-trainer/diffuser_fine_tuning\n", "\n", - "def install_checkpoint():\n", - " # Iterate through list of models to install\n", - " for model in installModels:\n", - " # Call install function for each model\n", - " install(model[0], model[1])\n", + " !python tag_images_by_wd14_tagger.py \\\n", + " \"{train_folder_directory}\" \\\n", + " --batch_size {global_batch_size} \\\n", + " --caption_extension .txt\n", + "else:\n", + " pass\n" + ], + "metadata": { + "cellView": "form", + "id": "nvPyH-G_Qdha" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Datasets cleaner\n", + "#@markdown This will delete unnecessary files and unsupported media like `.mp4`, `.webm`, and `.gif`\n", "\n", - "# Call install_checkpoint function to download all models in the list\n", - "install_checkpoint()\n" + "%cd /content\n", + "\n", + "import os\n", + "\n", + "folder_target = \"/content/dreambooth/train_1girl/20_makora 1girl\" #@param {'type' : 'string'}\n", + "\n", + "test = os.listdir(folder_target)\n", + "\n", + "#@markdown I recommend to `keep_metadata` especially if you're doing resume training and you have metadata and bucket latents file from previous training like `.npz`, `.txt`, and `.caption`.\n", + "keep_metadata = True #@param {'type':'boolean'}\n", + "\n", + "# List of supported file types\n", + "if keep_metadata == True:\n", + " supported_types = [\".jpg\", \".jpeg\", \".png\", \".caption\", \".npz\", \".txt\"]\n", + "else:\n", + " supported_types = [\".jpg\", \".jpeg\", \".png\"]\n", + "\n", + "# Iterate over all files in the directory\n", + "for item in test:\n", + " # Extract the file extension from the file name\n", + " file_ext = os.path.splitext(item)[1]\n", + " # If the file extension is not in the list of supported types, delete the file\n", + " if file_ext not in supported_types:\n", + " # Print a message indicating the name of the file being deleted\n", + " print(f\"Deleting file {item} from {folder_target}\")\n", + " # Delete the file\n", + " os.remove(os.path.join(folder_target, item))\n" ], "metadata": { - "id": "SoucgZQ6jgPQ", - "cellView": "form" + "cellView": "form", + "id": "Ug648uiOvUZn" }, "execution_count": null, "outputs": [] @@ -430,26 +672,61 @@ "source": [ "#@title Training begin\n", "num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n", + "save_state = True #@param {'type':'boolean'}\n", "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Anything-V3.0-pruned.ckpt\" #@param {'type':'string'}\n", - "train_data_dir = \"/content/dreambooth/train_hitokomoru\" #@param {'type':'string'}\n", - "reg_data_dir = \"/content/dreambooth/reg_hitokomoru\" #@param {'type':'string'}\n", - "output_dir =\"/content/dreambooth\" #@param {'type':'string'}\n", + "train_data_dir = \"/content/makora-dreambooth/dreambooth/train_1girl\" #@param {'type':'string'}\n", + "reg_data_dir = \"/content/makora-dreambooth/dreambooth/reg_1girl\" #@param {'type':'string'}\n", + "output_dir = \"/content/drive/MyDrive/dreambooth\" #@param {'type':'string'}\n", + "resume_path =\"\" #@param {'type':'string'}\n", "train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n", - "resolution = \"512,512\" #@param [\"512,512\", \"768,768\"] {allow-input: false}\n", + "resolution = \"512\" #@param [\"512\", \"768\"] {allow-input: false}\n", "learning_rate =\"2e-6\" #@param {'type':'string'}\n", "mixed_precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n", - "max_train_steps = 5000 #@param {'type':'integer'}\n", + "max_train_steps = 1600 #@param {'type':'integer'}\n", "save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n", - "save_every_n_epochs = 10 #@param {'type':'integer'}\n", + "save_every_n_epochs = 50 #@param {'type':'integer'}\n", + "caption_extension = \".txt\" #@param [\"none\", \".caption\", \".txt\"] {allow-input: false}\n", + "\n", + "#@markdown ### Log And Debug\n", + "log_prefix = \"dreambooth-style1\" #@param {'type':'string'}\n", + "logs_dst = \"/content/kohya-trainer/logs\" #@param {'type':'string'}\n", + "debug_mode = False #@param {'type':'boolean'}\n", + "\n", + "if debug_mode == True:\n", + " debug_dataset = \"--debug_dataset\"\n", + "else:\n", + " debug_dataset = \"\"\n", + "\n", + "if save_state == True:\n", + " sv_state = \"--save_state\"\n", + "else:\n", + " sv_state = \"\"\n", + "\n", + "if resume_path != \"\":\n", + " rs_state = \"--resume \" + str(resume_path)\n", + "else:\n", + " rs_state = \"\"\n", + "\n", + "if caption_extension != \"none\":\n", + " captions_n_tags = \"--caption_extension =\" + str(caption_extension)\n", + " shuffle =\"--shuffle_caption\"\n", + "else:\n", + " captions_n_tags = \"\"\n", + " shuffle = \"\"\n", "\n", + " \n", "%cd /content/kohya-trainer/train_db_fixed\n", - "!accelerate launch --config_file {accelerate_config} --num_cpu_threads_per_process {num_cpu_threads_per_process} train_db_fixed.py \\\n", + "!accelerate launch \\\n", + " --config_file /content/kohya-trainer/accelerate_config/config.yaml \\\n", + " --num_cpu_threads_per_process {num_cpu_threads_per_process} \\\n", + " train_db_fixed.py \\\n", " --pretrained_model_name_or_path={pre_trained_model_path} \\\n", " --train_data_dir={train_data_dir} \\\n", " --reg_data_dir={reg_data_dir} \\\n", " --output_dir={output_dir} \\\n", " --prior_loss_weight=1.0 \\\n", " --resolution={resolution} \\\n", + " --save_precision {save_precision} \\\n", " --train_batch_size={train_batch_size}\\\n", " --learning_rate={learning_rate}\\\n", " --max_train_steps={max_train_steps} \\\n", @@ -459,7 +736,14 @@ " --gradient_checkpointing \\\n", " --save_every_n_epochs={save_every_n_epochs} \\\n", " --enable_bucket \\\n", - " --cache_latents \n" + " --cache_latents \\\n", + " {shuffle} \\\n", + " {debug_dataset} \\\n", + " {captions_n_tags} \\\n", + " {sv_state} \\\n", + " {rs_state} \\\n", + " --logging_dir={logs_dst} \\\n", + " --log_prefix {log_prefix}\n" ], "metadata": { "id": "X_Rd3Eh07xlA", @@ -471,45 +755,91 @@ { "cell_type": "code", "source": [ + "#@title Inference\n", + "prompt = \"masterpiece, best quality, makora 1girl, solo, \" #@param {type: \"string\"}\n", + "negative = \"lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\" #@param {type: \"string\"}\n", + "model = \"/content/drive/MyDrive/dreambooth/last.ckpt\" #@param {type: \"string\"}\n", + "vae = \"/content/kohya-trainer/checkpoint/anime.vae.pt\" #@param {type: \"string\"}\n", + "output_dir = \"/content/tmp\" #@param {type: \"string\"}\n", + "scale = 12 #@param {type: \"slider\", min: 1, max: 40}\n", + "sampler = \"ddim\" #@param [\"ddim\", \"pndm\", \"lms\", \"euler\", \"euler_a\", \"heun\", \"dpm_2\", \"dpm_2_a\", \"dpmsolver\",\"dpmsolver++\", \"dpmsingle\", \"k_lms\", \"k_euler\", \"k_euler_a\", \"k_dpm_2\", \"k_dpm_2_a\"]\n", + "steps = 40 #@param {type: \"slider\", min: 1, max: 100}\n", + "precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n", + "width = 512 #@param {type: \"integer\"}\n", + "height = 512 #@param {type: \"integer\"}\n", + "batch_count = 6 #@param {type: \"integer\"}\n", + "batch_size = 1 #@param {type: \"integer\"}\n", + "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 40}\n", + "\n", + "!python /content/kohya-trainer/gen_img_diffusers/gen_img_diffusers.py \\\n", + " --ckpt {model} \\\n", + " --outdir {output_dir} \\\n", + " --xformers \\\n", + " --vae {vae} \\\n", + " --{precision} \\\n", + " --W {width} \\\n", + " --H {height} \\\n", + " --clip_skip {clip_skip} \\\n", + " --scale {scale} \\\n", + " --sampler {sampler} \\\n", + " --steps {steps} \\\n", + " --max_embeddings_multiples 3 \\\n", + " --batch_size {batch_size} \\\n", + " --images_per_prompt {batch_count} \\\n", + " --prompt \"{prompt} --n {negative}\"\n", + "\n" + ], + "metadata": { + "cellView": "form", + "id": "x6t8PbnF0gbg" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd /content/kohya-trainer/convert_diffusers20_original_sd\n", + "\n", "#@title Convert Weight to Diffusers or `.ckpt/.safetensors` (Optional)\n", "#@markdown ## Define weight path\n", - "weight = \"/content/kohya-trainer/fine-tuned/last.ckpt\" #@param {'type': 'string'}\n", - "weight_dir = {os.path.dirname(folder_path)}\n", + "weight = \"/content/kohya-trainer/fine-tuned/model.ckpt\" #@param {'type': 'string'}\n", + "weight_dir = os.path.dirname(weight)\n", "convert = \"diffusers_to_ckpt_safetensors\" #@param [\"diffusers_to_ckpt_safetensors\", \"ckpt_safetensors_to_diffusers\"] {'allow-input': false}\n", "\n", "#@markdown ## Conversion Config\n", "#@markdown\n", "#@markdown ### Diffusers to `.ckpt/.safetensors`\n", - "use_safetensors = False #@param {'type': 'boolean'}\n", + "use_safetensors = True #@param {'type': 'boolean'}\n", "\n", "if use_safetensors:\n", - " checkpoint = f\"{weight_dir}/model.safetensors\"\n", + " checkpoint = str(weight_dir)+\"/model.safetensors\"\n", "else:\n", - " checkpoint = f\"{weight_dir}/model.ckpt\"\n", + " checkpoint = str(weight_dir)+\"/model.ckpt\"\n", "\n", - "save_precision = \"--fp16\" #@param [\"--fp16\",\"--bf16\",\"--float\"] {'allow-input': false}\n", + "save_precision = \"--float\" #@param [\"--fp16\",\"--bf16\",\"--float\"] {'allow-input': false}\n", "\n", "#@markdown ### `.ckpt/.safetensors` to Diffusers\n", "#@markdown is your model v1 or v2 based Stable Diffusion Model\n", - " version = \"--v1\" #@param [\"--v1\",\"--v2\"] {'allow-input': false}\n", - " diffusers = f\"{weight_dir}/diffusers_model\" \n", + "version = \"--v1\" #@param [\"--v1\",\"--v2\"] {'allow-input': false}\n", + "diffusers = str(weight_dir)+\"/diffusers_model\"\n", "\n", "#@markdown Add reference model to get scheduler, optimizer, and tokenizer, because `.ckpt/.safetensors` didn't have one.\n", - "reference_model =\"\" #@param {'type': 'string'}\n", + "reference_model =\"runwayml/stable-diffusion-v1-5\" #@param {'type': 'string'}\n", "\n", "if convert == \"diffusers_to_ckpt_safetensors\":\n", " if not weight.endswith(\".ckpt\") or weight.endswith(\".safetensors\"):\n", - " !python convert_diffusers20_original_sd/convert_diffusers20_original_sd.py \\\n", + " !python convert_diffusers20_original_sd.py \\\n", " {weight} \\\n", " {checkpoint} \\\n", " {save_precision}\n", "\n", "else: \n", - " !python convert_diffusers20_original_sd/convert_diffusers20_original_sd.py \\\n", + " !python convert_diffusers20_original_sd.py \\\n", " {weight} \\\n", - " {diffusers} \\ \n", - " {v2} \\\n", - " --reference_model {reference_model} \n" + " {diffusers} \\\n", + " {version} \\\n", + " --reference_model {reference_model} " ], "metadata": { "cellView": "form", diff --git a/kohya-trainer.ipynb b/kohya-trainer.ipynb index 6e1bc433..b3d0bd8f 100644 --- a/kohya-trainer.ipynb +++ b/kohya-trainer.ipynb @@ -7,7 +7,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -16,7 +16,7 @@ "id": "slgjeYgd6pWp" }, "source": [ - "# Kohya Trainer V9 - VRAM 12GB\n", + "# Kohya Trainer V10 - VRAM 12GB\n", "### The Best Way for People Without Good GPUs to Fine-Tune the Stable Diffusion Model" ] }, @@ -66,89 +66,7 @@ " !git clone https://github.com/Linaqruf/kohya-trainer\n", "\n", "# Clone or update the Kohya Trainer repository\n", - "clone_kohya_trainer()\n", - "\n", - "# Change the current working directory to \"/content/kohya-trainer\".\n", - "%cd /content/kohya-trainer\n", - "\n", - "# Import `shutil` and `os` modules.\n", - "import shutil\n", - "import os\n", - "\n", - "# Initialize an empty list `custom_versions`.\n", - "custom_versions = []\n", - "\n", - "# Initialize a list `version_urls` containing URLs of different versions of the `diffusers_fine_tuning` file.\n", - "version_urls = [\"\",\\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v9/diffusers_fine_tuning_v9.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v8/diffusers_fine_tuning_v8.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v7/diffusers_fine_tuning_v7.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v6/diffusers_fine_tuning_v6.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v5/diffusers_fine_tuning_v5.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v4/diffusers_fine_tuning_v4.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v3/diffusers_fine_tuning_v3.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v2/diffusers_fine_tuning_v2.zip\", \\\n", - " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v1/diffusers_fine_tuning_v1.zip\"]\n", - "\n", - "# Initialize a list `version_names` containing names of different versions of the `diffusers_fine_tuning` file.\n", - "version_names = [\"latest_version\", \\\n", - " \"diffusers_fine_tuning_v9\", \\\n", - " \"diffusers_fine_tuning_v8\", \\\n", - " \"diffusers_fine_tuning_v7\", \\\n", - " \"diffusers_fine_tuning_v6\", \\\n", - " \"diffusers_fine_tuning_v5\", \\\n", - " \"diffusers_fine_tuning_v4\", \\\n", - " \"diffusers_fine_tuning_v3\", \\\n", - " \"diffusers_fine_tuning_v2\", \\\n", - " \"diffusers_fine_tuning_v1\"]\n", - "\n", - "# Initialize a variable `selected_version` to the selected version of the `diffusers_fine_tuning` file.\n", - "selected_version = \"latest_version\" #@param [\"latest_version\", \"diffusers_fine_tuning_v9\", \"diffusers_fine_tuning_v8\", \"diffusers_fine_tuning_v7\", \"diffusers_fine_tuning_v6\", \"diffusers_fine_tuning_v5\", \"diffusers_fine_tuning_v4\", \"diffusers_fine_tuning_v3\", \"diffusers_fine_tuning_v2\", \"diffusers_fine_tuning_v1\"]\n", - "\n", - "# Append a tuple to `custom_versions`, containing `selected_version` and the corresponding item\n", - "# in `version_urls`.\n", - "custom_versions.append((selected_version, version_urls[version_names.index(selected_version)]))\n", - "\n", - "# Define `download` function to download a file from the given URL and save it with\n", - "# the given name.\n", - "def download(name, url):\n", - " !wget -c \"{url}\" -O /content/{name}.zip\n", - "\n", - "# Define `unzip` function to unzip a file with the given name to a specified\n", - "# directory.\n", - "def unzip(name):\n", - " !unzip /content/{name}.zip -d /content/kohya-trainer/diffuser_fine_tuning\n", - "\n", - "# Define `download_version` function to download and unzip a file from `custom_versions`,\n", - "# unless `selected_version` is \"latest_version\".\n", - "def download_version():\n", - " if selected_version != \"latest_version\":\n", - " for zip in custom_versions:\n", - " download(zip[0], zip[1])\n", - "\n", - " # Rename the existing `diffuser_fine_tuning` directory to the `tmp` directory and delete any existing `tmp` directory.\n", - " if os.path.exists(\"/content/kohya-trainer/tmp\"):\n", - " shutil.rmtree(\"/content/kohya-trainer/tmp\")\n", - " os.rename(\"/content/kohya-trainer/diffuser_fine_tuning\", \"/content/kohya-trainer/tmp\")\n", - "\n", - " # Create a new empty `diffuser_fine_tuning` directory.\n", - " os.makedirs(\"/content/kohya-trainer/diffuser_fine_tuning\")\n", - " \n", - " # Unzip the downloaded file to the new `diffuser_fine_tuning` directory.\n", - " unzip(zip[0])\n", - " \n", - " # Delete the downloaded and unzipped file.\n", - " os.remove(\"/content/{}.zip\".format(zip[0]))\n", - " \n", - " # Inform the user that the existing `diffuser_fine_tuning` directory has been renamed to the `tmp` directory\n", - " # and a new empty `diffuser_fine_tuning` directory has been created.\n", - " print(\"Renamed existing 'diffuser_fine_tuning' directory to 'tmp' directory and created new empty 'diffuser_fine_tuning' directory.\")\n", - " else:\n", - " # Do nothing if `selected_version` is \"latest_version\".\n", - " pass\n", - "\n", - "# Call `download_version` function.\n", - "download_version()\n" + "clone_kohya_trainer()" ] }, { @@ -165,26 +83,18 @@ "\n", "import os\n", "\n", + "Install_xformers = True #@param {'type':'boolean'}\n", + " \n", "def install_dependencies():\n", " #@markdown This will install required Python packages\n", - " !pip install --upgrade -r script/requirements.txt\n", + " !pip install --upgrade -r requirements.txt\n", " !pip install -U gallery-dl\n", - " !pip install huggingface_hub\n", "\n", - " # Install WD1.4 Tagger dependencies\n", - " !pip install tensorflow\n", - " \n", - " Install_xformers = True #@param {'type':'boolean'}\n", - " \n", " if Install_xformers:\n", " !pip install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl\n", " else:\n", " pass\n", "\n", - " # Install BLIP dependencies\n", - " !pip install timm==0.4.12\n", - " !pip install fairscale==0.4.4\n", - "\n", "# Install dependencies\n", "install_dependencies()\n", "\n", @@ -198,6 +108,22 @@ "write_basic_config(save_location = accelerate_config) # Write a config file" ] }, + { + "cell_type": "code", + "source": [ + "#@title Restart Runtime\n", + "\n", + "import IPython\n", + "app = IPython.Application.instance()\n", + "app.kernel.do_shutdown(True)\n" + ], + "metadata": { + "cellView": "form", + "id": "iGR1wf_fEg0S" + }, + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", "metadata": { @@ -264,13 +190,13 @@ { "cell_type": "code", "source": [ - "#@title Define Train Data\n", + "#@title Define Train Data Directory\n", "#@markdown Define where your train data will be located. This cell will also create a folder based on your input. \n", "#@markdown This folder will be used as the target folder for scraping, tagging, bucketing, and training in the next cell.\n", "\n", "import os\n", "\n", - "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n", + "train_data_dir = \"/content/fine_tune/train_data\" #@param {'type' : 'string'}\n", "\n", "if not os.path.exists(train_data_dir):\n", " os.makedirs(train_data_dir)\n", @@ -412,7 +338,7 @@ "\n", "import os\n", "\n", - "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n", + "train_data_dir = \"/content/train_data\" #@param {'type' : 'string'}\n", "\n", "test = os.listdir(train_data_dir)\n", "\n", @@ -450,49 +376,14 @@ "cell_type": "code", "source": [ "#@title Start BLIP Captioning\n", - "%cd /content/kohya-trainer\n", - "\n", - "import shutil\n", - "import os\n", - "\n", - "def clone_and_prepare_spaces():\n", - " \"\"\"\n", - " Clones the Spaces repository, downloads the BLIP model weights, and moves the make_captions.py script to the BLIP directory.\n", - " \"\"\"\n", - " # Constants\n", - " BLIP_WEIGHT_SOURCE_URL = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'\n", - " BLIP_WEIGHT_DESTINATION_PATH = '/content/kohya-trainer/BLIP/model_large_caption.pth'\n", - " MAKE_CAPTION_SOURCE_PATH = '/content/kohya-trainer/diffuser_fine_tuning/make_captions.py'\n", - " MAKE_CAPTION_DESTINATION_PATH = '/content/kohya-trainer/BLIP/make_captions.py'\n", - "\n", - " # Install Git LFS\n", - " !git lfs install\n", - "\n", - " # Clone the Spaces repository\n", - " !git clone https://huggingface.co/spaces/Salesforce/BLIP\n", - "\n", - " # Download the BLIP model weights\n", - " !wget -c {BLIP_WEIGHT_SOURCE_URL} -O {BLIP_WEIGHT_DESTINATION_PATH}\n", - "\n", - " # Move the make_captions.py script to the BLIP directory\n", - " if os.path.exists(MAKE_CAPTION_SOURCE_PATH):\n", - " shutil.move(MAKE_CAPTION_SOURCE_PATH, MAKE_CAPTION_DESTINATION_PATH)\n", - " else:\n", - " pass\n", - "\n", - "# Clone and prepare Spaces\n", - "clone_and_prepare_spaces()\n", - "\n", - "%cd /content/kohya-trainer/BLIP\n", + "%cd /content/kohya-trainer/finetune\n", "\n", "#@markdown ### Define parameter:\n", "batch_size = 8 #@param {'type':'integer'}\n", "caption_extension = \".caption\" #@param [\".txt\",\".caption\"]\n", - "caption_weights = \"model_large_caption.pth\"\n", "\n", "!python make_captions.py \\\n", " {train_data_dir} \\\n", - " {caption_weights} \\\n", " --batch_size {batch_size} \\\n", " --caption_extension {caption_extension}" ], @@ -515,10 +406,10 @@ "#@title Start WD 1.4 Tagger\n", "\n", "# Change the working directory to the weight directory\n", - "%cd /content/kohya-trainer/diffuser_fine_tuning\n", + "%cd /content/kohya-trainer/finetune\n", "\n", "#@markdown ### Define parameter:\n", - "batch_size = 4 #@param {'type':'integer'}\n", + "batch_size = 8 #@param {'type':'integer'}\n", "caption_extension = \".txt\" #@param [\".txt\",\".caption\"]\n", "\n", "!python tag_images_by_wd14_tagger.py \\\n", @@ -532,12 +423,12 @@ "source": [ "#@title Create meta_clean.json \n", "# Change the working directory\n", - "%cd /content/kohya-trainer/diffuser_fine_tuning\n", + "%cd /content/kohya-trainer/finetune\n", "\n", "#@markdown ### Define Parameters\n", - "meta_cap_dd = \"/content/kohya-trainer/meta_cap_dd.json\" \n", - "meta_cap = \"/content/kohya-trainer/meta_cap.json\" \n", - "meta_clean = \"/content/kohya-trainer/meta_clean.json\" #@param {'type':'string'}\n", + "meta_cap_dd = \"/content/fine_tune/meta_cap_dd.json\" \n", + "meta_cap = \"/content/fine_tune/meta_cap.json\" \n", + "meta_clean = \"/content/fine_tune/meta_clean.json\" #@param {'type':'string'}\n", "\n", "# Check if the train_data_dir exists and is a directory\n", "if os.path.isdir(train_data_dir):\n", @@ -602,20 +493,24 @@ "outputs": [], "source": [ "#@title Install Pre-trained Model \n", - "%cd /content/kohya-trainer\n", + "%cd /content/\n", "import os\n", "\n", "# Check if directory exists\n", - "if not os.path.exists('checkpoint'):\n", + "if not os.path.exists('pre_trained_model'):\n", " # Create directory if it doesn't exist\n", - " os.makedirs('checkpoint')\n", + " os.makedirs('pre_trained_model')\n", "\n", "#@title Install Pre-trained Model \n", "\n", - "installModels=[]\n", + "installModels = []\n", + "installVae = []\n", + "installVaeArgs = []\n", + "installv2Models = []\n", "\n", "#@markdown ### Available Model\n", "#@markdown Select one of available pretrained model to download:\n", + "#@markdown ### SD1.x model\n", "modelUrl = [\"\", \\\n", " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \\\n", " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/modelsfw-pruned.ckpt\", \\\n", @@ -634,12 +529,39 @@ " \"Stable-Diffusion-v1-4\", \\\n", " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n", " \"Waifu-Diffusion-v1-3-fp32\"]\n", - "modelName = \"Anything-V3.0-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n", + "modelName = \"\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n", + "\n", + "#@markdown ### SD2.x model\n", + "v2ModelUrl = [\"\", \\\n", + " \"https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt\", \\\n", + " \"https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt\", \\\n", + " \"https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt\", \\\n", + " \"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt\", \\\n", + " \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt\"]\n", + "v2ModelList = [\"\", \\\n", + " \"stable-diffusion-2-base\", \\\n", + " \"stable-diffusion-2-768v\", \\\n", + " \"stable-diffusion-2-1-base\", \\\n", + " \"stable-diffusion-2-1-768v\", \\\n", + " \"waifu-diffusion-1-4-anime-e-1\"]\n", + "v2ModelName = \"waifu-diffusion-1-4-anime-e-1\" #@param [\"\", \"stable-diffusion-2-base\", \"stable-diffusion-2-768v\", \"stable-diffusion-2-1-base\", \"stable-diffusion-2-1-768v\", \"waifu-diffusion-1-4-anime-e-1\"]\n", "\n", "#@markdown ### Custom model\n", "#@markdown The model URL should be a direct download link.\n", "customName = \"\" #@param {'type': 'string'}\n", - "customUrl = \"\"#@param {'type': 'string'}\n", + "customUrl = \"\" #@param {'type': 'string'}\n", + "\n", + "\n", + "#@markdown Select one of the VAEs to download, select `none` for not download VAE:\n", + "vaeUrl = [\"\", \\\n", + " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animevae/animevae.pt\", \\\n", + " \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\", \\\n", + " \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\"]\n", + "vaeList = [\"none\", \\\n", + " \"anime.vae.pt\", \\\n", + " \"waifudiffusion.vae.pt\", \\\n", + " \"stablediffusion.vae.pt\"]\n", + "vaeName = \"waifudiffusion.vae.pt\" #@param [\"none\", \"anime.vae.pt\", \"waifudiffusion.vae.pt\", \"stablediffusion.vae.pt\"]\n", "\n", "# Check if user has specified a custom model\n", "if customName != \"\" and customUrl != \"\":\n", @@ -651,24 +573,40 @@ " # Map selected model to URL\n", " installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n", "\n", + "# Check if user has selected a model\n", + "if v2ModelName != \"\":\n", + " # Map selected model to URL\n", + " installv2Models.append((v2ModelName, v2ModelUrl[v2ModelList.index(v2ModelName)]))\n", + "\n", + "installVae.append((vaeName, vaeUrl[vaeList.index(vaeName)]))\n", + "\n", "def install_aria():\n", " # Install aria2 if it is not already installed\n", " if not os.path.exists('/usr/bin/aria2c'):\n", " !apt install -y -qq aria2\n", "\n", "def install(checkpoint_name, url):\n", + " if url.endswith(\".ckpt\"):\n", + " dst = \"/content/pre_trained_model/\" + str(checkpoint_name) + \".ckpt\"\n", + " elif url.endswith(\".safetensors\"):\n", + " dst = \"/content/pre_trained_model/\" + str(checkpoint_name) + \".safetensors\"\n", + " elif url.endswith(\".pt\"):\n", + " dst = \"/content/pre_trained_model/\" + str(checkpoint_name)\n", + " else:\n", + " dst = \"/content/pre_trained_model/\" + str(checkpoint_name) + \".ckpt\"\n", + "\n", " if url.startswith(\"https://drive.google.com\"):\n", " # Use gdown to download file from Google Drive\n", - " !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n", + " !gdown --fuzzy -O {dst} \"{url}\"\n", " elif url.startswith(\"magnet:?\"):\n", " install_aria()\n", " # Use aria2c to download file from magnet link\n", - " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n", + " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o {dst} \"{url}\"\n", " else:\n", " user_token = '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'\n", " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n", " # Use wget to download file from URL\n", - " !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n", + " !wget -c --header={user_header} \"{url}\" -O {dst}\n", "\n", "def install_checkpoint():\n", " # Iterate through list of models to install\n", @@ -676,6 +614,17 @@ " # Call install function for each model\n", " install(model[0], model[1])\n", "\n", + " # Iterate through list of models to install\n", + " for v2model in installv2Models:\n", + " # Call install function for each v2model\n", + " install(v2model[0], v2model[1])\n", + " \n", + " if vaeName != \"none\":\n", + " for vae in installVae:\n", + " install(vae[0], vae[1])\n", + " else:\n", + " pass\n", + "\n", "# Call install_checkpoint function to download all models in the list\n", "install_checkpoint()\n" ] @@ -692,23 +641,28 @@ "#@title Aspect Ratio Bucketing\n", "\n", "# Change working directory\n", - "%cd /content/kohya-trainer/diffuser_fine_tuning\n", + "%cd /content/kohya-trainer/finetune\n", "\n", "#@markdown ### Define parameters\n", - "\n", - "model_dir = \"Linaqruf/hitokomoru-diffusion\" #@param {'type' : 'string'} \n", + "V2 = True #@param{type:\"boolean\"}\n", + "model_dir = \"/content/pre_trained_model/waifu-diffusion-1-4-anime-e-1.ckpt\" #@param {'type' : 'string'} \n", + "input_json = \"/content/fine_tune/meta_clean.json\" #@param {'type' : 'string'} \n", + "output_json = \"/content/fine_tune/meta_lat.json\"#@param {'type' : 'string'} \n", "batch_size = 4 #@param {'type':'integer'}\n", - "max_resolution = \"512,512\" #@param [\"512,512\", \"768,768\"] {allow-input: false}\n", + "max_resolution = \"768,768\" #@param [\"512,512\", \"640,640\", \"768,768\"] {allow-input: false}\n", "mixed_precision = \"no\" #@param [\"no\", \"fp16\", \"bf16\"] {allow-input: false}\n", - "meta_clean = \"/content/kohya-trainer/meta_clean.json\"\n", - "meta_lat = \"/content/kohya-trainer/meta_lat.json\"\n", "\n", + "if V2:\n", + " SDV2 = \"--v2\"\n", + "else:\n", + " SDV2 = \"\"\n", "# Run script to prepare buckets and latents\n", "!python prepare_buckets_latents.py \\\n", " {train_data_dir} \\\n", - " {meta_clean} \\\n", - " {meta_lat} \\\n", + " {input_json} \\\n", + " {output_json} \\\n", " {model_dir} \\\n", + " {SDV2} \\\n", " --batch_size {batch_size} \\\n", " --max_resolution {max_resolution} \\\n", " --mixed_precision {mixed_precision}\n", @@ -734,10 +688,10 @@ "#@title Define Important folder\n", "import os\n", "\n", - "pre_trained_model_path =\"Linaqruf/hitokomoru-diffusion\" #@param {'type':'string'}\n", - "meta_lat_json_dir = \"/content/kohya-trainer/meta_lat.json\" #@param {'type':'string'}\n", - "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type':'string'}\n", - "output_dir =\"/content/kohya-trainer/fine-tuned\" #@param {'type':'string'}\n", + "pre_trained_model_path =\"/content/pre_trained_model/waifu-diffusion-1-4-anime-e-1.ckpt\" #@param {'type':'string'}\n", + "meta_lat_json_dir = \"/content/fine_tune/meta_lat.json\" #@param {'type':'string'}\n", + "train_data_dir = \"/content/fine_tune/train_data\" #@param {'type':'string'}\n", + "output_dir =\"/content/fine_tune/output\" #@param {'type':'string'}\n", "resume_path = \"\" #@param {'type':'string'}\n", "\n", "# List of important folder paths\n", @@ -780,24 +734,47 @@ "source": [ "#@title Training begin\n", "#@markdown ### Define Parameters\n", + "\n", + "V2 = \"V2_base\" #@param [\"none\", \"V2_base\", \"V2_768_v\"] {allow-input: false}\n", "num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n", "save_state = True #@param {'type':'boolean'}\n", "train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n", "learning_rate =\"2e-6\" #@param {'type':'string'}\n", - "max_train_steps = 1000 #@param {'type':'integer'}\n", + "max_train_steps = 2500 #@param {'type':'integer'}\n", "train_text_encoder = False #@param {'type':'boolean'}\n", "lr_scheduler = \"constant\" #@param [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"] {allow-input: false}\n", "max_token_length = \"225\" #@param [\"150\", \"225\"] {allow-input: false}\n", "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n", - "mixed_precision = \"fp16\" #@param [\"no\",\"fp16\",\"bf16\"] {allow-input: false}\n", + "mixed_precision = \"no\" #@param [\"no\",\"fp16\",\"bf16\"] {allow-input: false}\n", + "save_model_as = \"ckpt\" #@param [\"default\", \"ckpt\", \"safetensors\", \"diffusers\", \"diffusers_safetensors\"] {allow-input: false}\n", "save_precision = \"fp16\" #@param [\"None\",\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n", "save_every_n_epochs = 50 #@param {'type':'integer'}\n", "gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n", "#@markdown ### Log And Debug\n", "log_prefix = \"fine-tune-style1\" #@param {'type':'string'}\n", - "logs_dst = \"/content/kohya-trainer/logs\" #@param {'type':'string'}\n", + "logs_dst = \"/content/fine_tune/training_logs\" #@param {'type':'string'}\n", "debug_mode = False #@param {'type':'boolean'}\n", "\n", + "if V2 == \"V2_base\":\n", + " v2_model = \"--v2\"\n", + " v2_768v_model= \"\"\n", + "elif V2 == \"V2_768_v\":\n", + " v2_model = \"--v2\"\n", + " v2_768v_model = \"--v2_parameterization\"\n", + "else:\n", + " v2_model = \"\"\n", + " v2_768v_model = \"\"\n", + "\n", + "if V2 != \"\":\n", + " penultimate_layer = \"\"\n", + "else:\n", + " penultimate_layer = \"--clip_skip\" + \"=\" + \"{}\".format(clip_skip)\n", + "\n", + "if save_model_as != \"default\":\n", + " sv_model = \"--save_model_as= \" + str(save_model_as)\n", + "else: \n", + " sv_model = \"\"\n", + "\n", "if save_state == True:\n", " sv_state = \"--save_state\"\n", "else:\n", @@ -829,11 +806,14 @@ " text_encoder = \"\"\n", "\n", "\n", - "%cd /content/kohya-trainer/diffuser_fine_tuning\n", + "%cd /content/kohya-trainer\n", + "\n", "!accelerate launch \\\n", " --config_file /content/kohya-trainer/accelerate_config/config.yaml \\\n", " --num_cpu_threads_per_process {num_cpu_threads_per_process} \\\n", " fine_tune.py \\\n", + " {v2_model} \\\n", + " {v2_768v_model} \\\n", " --pretrained_model_name_or_path={pre_trained_model_path} \\\n", " --in_json {meta_lat_json_dir} \\\n", " --train_data_dir={train_data_dir} \\\n", @@ -843,13 +823,14 @@ " --learning_rate={learning_rate} \\\n", " --lr_scheduler={lr_scheduler} \\\n", " --max_token_length={max_token_length} \\\n", - " --clip_skip={clip_skip} \\\n", + " {penultimate_layer} \\\n", " --mixed_precision={mixed_precision} \\\n", " --max_train_steps={max_train_steps} \\\n", " --use_8bit_adam \\\n", " --xformers \\\n", " --gradient_checkpointing \\\n", " --gradient_accumulation_steps {gradient_accumulation_steps} \\\n", + " {sv_model} \\\n", " {text_encoder} \\\n", " {sv_state} \\\n", " {rs_state} \\\n", @@ -866,19 +847,87 @@ "id": "vqfgyL-thgdw" }, "source": [ - "# Miscellaneous" + "# Extras" ] }, { "cell_type": "code", "source": [ - "%cd /content/kohya-trainer/convert_diffusers20_original_sd\n", + "#@title Inference\n", + "V2 = \"V2_base\" #@param [\"none\", \"V2_base\", \"V2_768_v\"] {allow-input: false}\n", + "prompt = \"masterpiece, best quality, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt\" #@param {type: \"string\"}\n", + "negative = \"worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry\" #@param {type: \"string\"}\n", + "model = \"/content/fine_tune/output/last.ckpt\" #@param {type: \"string\"}\n", + "vae = \"/content/pre_trained_model/waifudiffusion.vae.pt\" #@param {type: \"string\"}\n", + "output_dir = \"/content/tmp\" #@param {type: \"string\"}\n", + "scale = 12 #@param {type: \"slider\", min: 1, max: 40}\n", + "sampler = \"ddim\" #@param [\"ddim\", \"pndm\", \"lms\", \"euler\", \"euler_a\", \"heun\", \"dpm_2\", \"dpm_2_a\", \"dpmsolver\",\"dpmsolver++\", \"dpmsingle\", \"k_lms\", \"k_euler\", \"k_euler_a\", \"k_dpm_2\", \"k_dpm_2_a\"]\n", + "steps = 28 #@param {type: \"slider\", min: 1, max: 100}\n", + "precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n", + "width = 768 #@param {type: \"integer\"}\n", + "height = 768 #@param {type: \"integer\"}\n", + "batch_count = 1 #@param {type: \"integer\"}\n", + "batch_size = 1 #@param {type: \"integer\"}\n", + "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 40}\n", + "\n", + "if vae != \"\":\n", + " load_vae =\"--vae \" + str(vae)\n", + "else:\n", + " load_vae =\"\" \n", + "\n", + "if V2 == \"V2_base\":\n", + " v2_model = \"--v2\"\n", + " v2_768v_model= \"\"\n", + "elif V2 == \"V2_768_v\":\n", + " v2_model = \"--v2\"\n", + " v2_768v_model = \"--v2_parameterization\"\n", + "else:\n", + " v2_model = \"\"\n", + " v2_768v_model = \"\"\n", + "\n", + "if V2 != \"\":\n", + " penultimate_layer = \"\"\n", + "else:\n", + " penultimate_layer = \"--clip_skip\" + \"=\" + \"{}\".format(clip_skip)\n", + "\n", + "%cd /content/kohya-trainer\n", + "!python gen_img_diffusers.py \\\n", + " {v2_model} \\\n", + " {v2_768v_model} \\\n", + " --ckpt {model} \\\n", + " --outdir {output_dir} \\\n", + " --xformers \\\n", + " {load_vae} \\\n", + " --{precision} \\\n", + " --W {width} \\\n", + " --H {height} \\\n", + " {penultimate_layer} \\\n", + " --scale {scale} \\\n", + " --sampler {sampler} \\\n", + " --steps {steps} \\\n", + " --max_embeddings_multiples 3 \\\n", + " --batch_size {batch_size} \\\n", + " --images_per_prompt {batch_count} \\\n", + " --prompt \"{prompt} --n {negative}\"\n", + "\n" + ], + "metadata": { + "cellView": "form", + "id": "j1jJ4z3AXRO9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%cd /content/kohya-trainer/tools\n", "\n", "#@title Convert Weight to Diffusers or `.ckpt/.safetensors` (Optional)\n", "#@markdown ## Define weight path\n", - "weight = \"/content/kohya-trainer/fine-tuned/model.ckpt\" #@param {'type': 'string'}\n", + "weight = \"/content/fine_tune/output/last.ckpt\" #@param {'type': 'string'}\n", "weight_dir = os.path.dirname(weight)\n", - "convert = \"diffusers_to_ckpt_safetensors\" #@param [\"diffusers_to_ckpt_safetensors\", \"ckpt_safetensors_to_diffusers\"] {'allow-input': false}\n", + "convert = \"ckpt_safetensors_to_diffusers\" #@param [\"diffusers_to_ckpt_safetensors\", \"ckpt_safetensors_to_diffusers\"] {'allow-input': false}\n", "\n", "#@markdown ## Conversion Config\n", "#@markdown\n", @@ -894,11 +943,11 @@ "\n", "#@markdown ### `.ckpt/.safetensors` to Diffusers\n", "#@markdown is your model v1 or v2 based Stable Diffusion Model\n", - "version = \"--v1\" #@param [\"--v1\",\"--v2\"] {'allow-input': false}\n", + "version = \"--v2\" #@param [\"--v1\",\"--v2\"] {'allow-input': false}\n", "diffusers = str(weight_dir)+\"/diffusers_model\"\n", "\n", "#@markdown Add reference model to get scheduler, optimizer, and tokenizer, because `.ckpt/.safetensors` didn't have one.\n", - "reference_model =\"runwayml/stable-diffusion-v1-5\" #@param {'type': 'string'}\n", + "reference_model =\"hakurei/waifu-diffusion\" #@param {'type': 'string'}\n", "\n", "if convert == \"diffusers_to_ckpt_safetensors\":\n", " if not weight.endswith(\".ckpt\") or weight.endswith(\".safetensors\"):\n", @@ -950,7 +999,7 @@ "input = \"/content/kohya-trainer/fine-tuned/model.ckpt\" #@param {'type' : 'string'}\n", "\n", "# Use a more descriptive variable name\n", - "output = \"/content/kohya-trainer/fine-tuned/model.ckpt\" #@param {'type' : 'string'}\n", + "output = \"/content/kohya-trainer/fine-tuned/model-pruned.ckpt\" #@param {'type' : 'string'}\n", "\n", "if prune:\n", " import os\n", @@ -977,9 +1026,11 @@ "cell_type": "code", "source": [ "#@title Visualize loss graph (Optional)\n", + "training_logs_path = \"/content/fine_tune/training_logs\" #@param {type : \"string\"}\n", + "\n", "%cd /content/kohya-trainer\n", "%load_ext tensorboard\n", - "%tensorboard --logdir {logs_dst}" + "%tensorboard --logdir {training_logs_path}" ], "metadata": { "cellView": "form", diff --git a/library/__init__.py b/library/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diffuser_fine_tuning/model_util.py b/library/model_util.py similarity index 99% rename from diffuser_fine_tuning/model_util.py rename to library/model_util.py index f3453025..398b6404 100644 --- a/diffuser_fine_tuning/model_util.py +++ b/library/model_util.py @@ -624,8 +624,16 @@ def convert_key(key): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids return new_sd # endregion diff --git a/networks/lora.py b/networks/lora.py new file mode 100644 index 00000000..730a6376 --- /dev/null +++ b/networks/lora.py @@ -0,0 +1,190 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +import torch + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == 'Conv2d': + in_dim = org_module.in_channels + out_dim = org_module.out_channels + self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) + self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x): + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier + + +def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): + if network_dim is None: + network_dim = 4 # default + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) + return network + + +class LoRANetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + LORA_PREFIX_UNET = 'lora_unet' + LORA_PREFIX_TEXT_ENCODER = 'lora_te' + + def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: + super().__init__() + self.multiplier = multiplier + self.lora_dim = lora_dim + + # create module instances + def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) + loras.append(lora) + return loras + + self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, + text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def load_weights(self, file): + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file + self.weights_sd = load_file(file) + else: + self.weights_sd = torch.load(file, map_location='cpu') + + def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None): + if self.weights_sd: + weights_has_text_encoder = weights_has_unet = False + for key in self.weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER): + weights_has_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_UNET): + weights_has_unet = True + + if apply_text_encoder is None: + apply_text_encoder = weights_has_text_encoder + else: + assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています" + + if apply_unet is None: + apply_unet = weights_has_unet + else: + assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています" + else: + assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set" + + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + if self.weights_sd: + # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros) + info = self.load_state_dict(self.weights_sd, False) + print(f"weights are loaded: {info}") + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr): + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + self.requires_grad_(True) + params = [] + + if self.text_encoder_loras: + param_data = {'params': enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data['lr'] = text_encoder_lr + params.append(param_data) + + if self.unet_loras: + param_data = {'params': enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data['lr'] = unet_lr + params.append(param_data) + + return params + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype): + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import save_file + save_file(state_dict, file) + else: + torch.save(state_dict, file) diff --git a/networks/merge_lora.py b/networks/merge_lora.py new file mode 100644 index 00000000..d873a8ef --- /dev/null +++ b/networks/merge_lora.py @@ -0,0 +1,159 @@ + + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +import library.model_util as model_util +import lora + + +def load_state_dict(file_name, dtype): + if os.path.splitext(file_name)[1] == '.safetensors': + sd = load_file(file_name) + else: + sd = torch.load(file_name, map_location='cpu') + for key in list(sd.keys()): + if type(sd[key]) == torch.Tensor: + sd[key] = sd[key].to(dtype) + return sd + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): + text_encoder.to(merge_dtype) + unet.to(merge_dtype) + + # create module map + name_to_module = {} + for i, root_module in enumerate([text_encoder, unet]): + if i == 0: + prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER + target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + else: + prefix = lora.LoRANetwork.LORA_PREFIX_UNET + target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): + lora_name = prefix + '.' + name + '.' + child_name + lora_name = lora_name.replace('.', '_') + name_to_module[lora_name] = child_module + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if "lora_down" in key: + up_key = key.replace("lora_down", "lora_up") + + # find original module for this lora + module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" + if module_name not in name_to_module: + print(f"no module found for LoRA weight: {key}") + continue + module = name_to_module[module_name] + # print(f"apply {key} to {module}") + + down_weight = lora_sd[key] + up_weight = lora_sd[up_key] + + # W <- W + U * D + weight = module.weight + if len(weight.size()) == 2: + # linear + weight = weight + ratio * (up_weight @ down_weight) + else: + # conv2d + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + + module.weight = torch.nn.Parameter(weight) + + +def merge_lora_models(models, ratios, merge_dtype): + merged_sd = {} + + for model, ratio in zip(models, ratios): + print(f"loading: {model}") + lora_sd = load_state_dict(model, merge_dtype) + + print(f"merging...") + for key in lora_sd.keys(): + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + merged_sd[key] = lora_sd[key] * ratio + + return merged_sd + + +def merge(args): + assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください" + + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + merge_dtype = str_to_dtype(args.precision) + save_dtype = str_to_dtype(args.save_precision) + if save_dtype is None: + save_dtype = merge_dtype + + if args.sd_model is not None: + print(f"loading SD model: {args.sd_model}") + + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) + + merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) + + print(f"saving SD model to: {args.save_to}") + model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, + args.sd_model, 0, 0, save_dtype, vae) + else: + state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + + print(f"saving model to: {args.save_to}") + save_to_file(args.save_to, state_dict, state_dict, save_dtype) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") + parser.add_argument("--precision", type=str, default="float", + choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") + parser.add_argument("--sd_model", type=str, default=None, + help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--models", type=str, nargs='*', + help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors") + parser.add_argument("--ratios", type=float, nargs='*', + help="ratios for each model / それぞれのLoRAモデルの比率") + + args = parser.parse_args() + merge(args) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..36f48a0f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +accelerate==0.15.0 +transformers==4.25.1 +ftfy +albumentations +opencv-python +einops +diffusers[torch]==0.10.2 +pytorch_lightning +bitsandbytes==0.35.0 +tensorboard +safetensors==0.2.6 +gradio +altair +easygui +# for BLIP captioning +requests +timm==0.4.12 +fairscale==0.4.4 +# for WD14 captioning +tensorflow<2.11 +huggingface-hub +# for kohya_ss library +. \ No newline at end of file diff --git a/script/detect_face_rotate_v3.py b/script/detect_face_rotate_v3.py deleted file mode 100644 index 6e1c3225..00000000 --- a/script/detect_face_rotate_v3.py +++ /dev/null @@ -1,235 +0,0 @@ -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - -# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す - -# v2: extract max face if multiple faces are found -# v3: add crop_ratio option - -import argparse -import math -import cv2 -import glob -import os -from anime_face_detector import create_detector -from tqdm import tqdm -import numpy as np - -KP_REYE = 11 -KP_LEYE = 19 - -SCORE_THRES = 0.90 - - -def detect_face(detector, image): - preds = detector(image) # bgr - # print(len(preds)) - if len(preds) == 0: - return None, None, None, None, None - - index = -1 - max_score = 0 - max_size = 0 - for i in range(len(preds)): - bb = preds[i]['bbox'] - score = bb[-1] - size = max(bb[2]-bb[0], bb[3]-bb[1]) - if (score > max_score and max_score < SCORE_THRES) or (score >= SCORE_THRES and size > max_size): - index = i - max_score = score - max_size = size - - left = preds[index]['bbox'][0] - top = preds[index]['bbox'][1] - right = preds[index]['bbox'][2] - bottom = preds[index]['bbox'][3] - cx = int((left + right) / 2) - cy = int((top + bottom) / 2) - fw = int(right - left) - fh = int(bottom - top) - - lex, ley = preds[index]['keypoints'][KP_LEYE, 0:2] - rex, rey = preds[index]['keypoints'][KP_REYE, 0:2] - angle = math.atan2(ley - rey, lex - rex) - angle = angle / math.pi * 180 - return cx, cy, fw, fh, angle - - -def rotate_image(image, angle, cx, cy): - h, w = image.shape[0:2] - rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) - - # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 - # nh = max(h, int(w * math.sin(angle))) - # nw = max(w, int(h * math.sin(angle))) - # if nh > h or nw > w: - # pad_y = nh - h - # pad_t = pad_y // 2 - # pad_x = nw - w - # pad_l = pad_x // 2 - # m = np.array([[0, 0, pad_l], - # [0, 0, pad_t]]) - # rot_mat = rot_mat + m - # h, w = nh, nw - # cx += pad_l - # cy += pad_t - - result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) - return result, cx, cy - - -def process(args): - assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" - assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" - - # アニメ顔検出モデルを読み込む - print("loading face detector.") - detector = create_detector('yolov3') - - # cropの引数を解析する - if args.crop_size is None: - crop_width = crop_height = None - else: - tokens = args.crop_size.split(',') - assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" - crop_width, crop_height = [int(t) for t in tokens] - - if args.crop_ratio is None: - crop_h_ratio = crop_v_ratio = None - else: - tokens = args.crop_ratio.split(',') - assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" - crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] - - # 画像を処理する - print("processing.") - output_extension = ".png" - - os.makedirs(args.dst_dir, exist_ok=True) - paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.src_dir, "*.webp")) - for path in tqdm(paths): - basename = os.path.splitext(os.path.basename(path))[0] - - # image = cv2.imread(path) # 日本語ファイル名でエラーになる - image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) - if len(image.shape) == 2: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) - if image.shape[2] == 4: - print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") - image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい - - h, w = image.shape[:2] - - cx, cy, fw, fh, angle = detect_face(detector, image) - if cx is None: - print(f"face not found, skip: {path}") - # cx = cy = fw = fh = 0 - continue # スキップする - - # オプション指定があれば回転する - if args.rotate and cx != 0: - image, cx, cy = rotate_image(image, angle, cx, cy) - - # オプション指定があれば顔を中心に切り出す - if crop_width is not None or crop_h_ratio is not None: - assert cx > 0, f"face not found for cropping: {path}" - cur_crop_width, cur_crop_height = crop_width, crop_height - if crop_h_ratio is not None: - cur_crop_width = int(max(fw, fh) * crop_h_ratio + .5) - cur_crop_height = int(max(fw, fh) * crop_v_ratio + .5) - - # リサイズを必要なら行う - scale = 1.0 - if args.resize_face_size is not None: - # 顔サイズを基準にリサイズする - scale = args.resize_face_size / max(fw, fh) - if scale < cur_crop_width / w: - print( - f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") - scale = cur_crop_width / w - if scale < cur_crop_height / h: - print( - f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") - scale = cur_crop_height / h - elif crop_h_ratio is not None: - # 倍率指定の時にはリサイズしない - pass - else: - # 切り出しサイズ指定あり - if w < cur_crop_width: - print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") - scale = cur_crop_width / w - if h < cur_crop_height: - print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") - scale = cur_crop_height / h - if args.resize_fit: - scale = max(cur_crop_width / w, cur_crop_height / h) - - if scale != 1.0: - w = int(w * scale + .5) - h = int(h * scale + .5) - image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) - cx = int(cx * scale + .5) - cy = int(cy * scale + .5) - fw = int(fw * scale + .5) - fh = int(fh * scale + .5) - - cur_crop_width = min(cur_crop_width, image.shape[1]) - cur_crop_height = min(cur_crop_height, image.shape[0]) - - x = cx - cur_crop_width // 2 - cx = cur_crop_width // 2 - if x < 0: - cx = cx + x - x = 0 - elif x + cur_crop_width > w: - cx = cx + (x + cur_crop_width - w) - x = w - cur_crop_width - image = image[:, x:x+cur_crop_width] - - y = cy - cur_crop_height // 2 - cy = cur_crop_height // 2 - if y < 0: - cy = cy + y - y = 0 - elif y + cur_crop_height > h: - cy = cy + (y + cur_crop_height - h) - y = h - cur_crop_height - image = image[y:y + cur_crop_height] - - # # debug - # print(path, cx, cy, angle) - # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) - # cv2.imshow("image", crp) - # if cv2.waitKey() == 27: - # break - # cv2.destroyAllWindows() - - # debug - if args.debug: - cv2.rectangle(image, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) - - # cv2.imwrite(os.path.join(args.dst_dir, f"{basename}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}.{output_extension}"), image) - _, buf = cv2.imencode(output_extension, image) - with open(os.path.join(args.dst_dir, f"{basename}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: - buf.tofile(f) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") - parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") - parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") - parser.add_argument("--resize_fit", action="store_true", - help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") - parser.add_argument("--resize_face_size", type=int, default=None, - help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") - parser.add_argument("--crop_size", type=str, default=None, - help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") - parser.add_argument("--crop_ratio", type=str, default=None, - help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") - parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") - args = parser.parse_args() - - process(args) diff --git a/script/requirements.txt b/script/requirements.txt deleted file mode 100644 index 768ff785..00000000 --- a/script/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ -accelerate==0.15.0 -transformers>=4.21.0 -ftfy -albumentations -opencv-python -einops -diffusers[torch]==0.10.2 -pytorch_lightning -bitsandbytes -tensorboard -safetensors diff --git a/script/tag_images_by_wd14_tagger.py b/script/tag_images_by_wd14_tagger.py deleted file mode 100644 index 66d3a34e..00000000 --- a/script/tag_images_by_wd14_tagger.py +++ /dev/null @@ -1,107 +0,0 @@ -# このスクリプトのライセンスは、Apache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - -import argparse -import csv -import glob -import os -import json - -from PIL import Image -from tqdm import tqdm -import numpy as np -from tensorflow.keras.models import load_model -from Utils import dbimutils - - -# from wd14 tagger -IMAGE_SIZE = 448 - - -def main(args): - image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) - print(f"found {len(image_paths)} images.") - - print("loading model and labels") - model = load_model(args.model) - - # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") - # 依存ライブラリを増やしたくないので自力で読むよ - with open(args.tag_csv, "r", encoding="utf-8") as f: - reader = csv.reader(f) - l = [row for row in reader] - header = l[0] # tag_id,name,category,count - rows = l[1:] - assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}" - - tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ - - # 推論する - def run_batch(path_imgs): - imgs = np.array([im for _, im in path_imgs]) - - probs = model(imgs, training=False) - probs = probs.numpy() - - for (image_path, _), prob in zip(path_imgs, probs): - # 最初の4つはratingなので無視する - # # First 4 labels are actually ratings: pick one with argmax - # ratings_names = label_names[:4] - # rating_index = ratings_names["probs"].argmax() - # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]] - - # それ以降はタグなのでconfidenceがthresholdより高いものを追加する - # Everything else is tags: pick any where prediction confidence > threshold - tag_text = "" - for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで - if p >= args.thresh: - tag_text += ", " + tags[i] - - if len(tag_text) > 0: - tag_text = tag_text[2:] # 最初の ", " を消す - - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f: - f.write(tag_text + '\n') - if args.debug: - print(image_path, tag_text) - - b_imgs = [] - for image_path in tqdm(image_paths): - img = dbimutils.smart_imread(image_path) - img = dbimutils.smart_24bit(img) - img = dbimutils.make_square(img, IMAGE_SIZE) - img = dbimutils.smart_resize(img, IMAGE_SIZE) - img = img.astype(np.float32) - b_imgs.append((image_path, img)) - - if len(b_imgs) >= args.batch_size: - run_batch(b_imgs) - b_imgs.clear() - if len(b_imgs) > 0: - run_batch(b_imgs) - - print("done!") - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s", - help="model path to load / 読み込むモデルファイル") - parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv", - help="csv file for tags / タグ一覧のCSVファイル") - parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子") - parser.add_argument("--debug", action="store_true", help="debug mode") - - args = parser.parse_args() - - # スペルミスしていたオプションを復元する - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - - main(args) diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..7bf54834 --- /dev/null +++ b/setup.py @@ -0,0 +1,3 @@ +from setuptools import setup, find_packages + +setup(name = "library", packages = find_packages()) \ No newline at end of file diff --git a/script/code-snippet.ipynb b/tools/code-snippet.ipynb similarity index 100% rename from script/code-snippet.ipynb rename to tools/code-snippet.ipynb diff --git a/convert_diffusers20_original_sd/convert_diffusers20_original_sd.py b/tools/convert_diffusers20_original_sd.py similarity index 98% rename from convert_diffusers20_original_sd/convert_diffusers20_original_sd.py rename to tools/convert_diffusers20_original_sd.py index 5df5f954..a3cd03fe 100644 --- a/convert_diffusers20_original_sd/convert_diffusers20_original_sd.py +++ b/tools/convert_diffusers20_original_sd.py @@ -9,7 +9,7 @@ import torch from diffusers import StableDiffusionPipeline -import model_util +import library.model_util as model_util def convert(args): @@ -48,7 +48,7 @@ def convert(args): v2_model = unet.config.cross_attention_dim == 1024 print("checking model version: model is " + ('v2' if v2_model else 'v1')) else: - v2_model = args.v1 + v2_model = not args.v1 # 変換して保存する msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py new file mode 100644 index 00000000..4d5e58d4 --- /dev/null +++ b/tools/detect_face_rotate.py @@ -0,0 +1,239 @@ +# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします +# (c) 2022 Kohya S. @kohya_ss + +# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す + +# v2: extract max face if multiple faces are found +# v3: add crop_ratio option +# v4: add multiple faces extraction and min/max size + +import argparse +import math +import cv2 +import glob +import os +from anime_face_detector import create_detector +from tqdm import tqdm +import numpy as np + +KP_REYE = 11 +KP_LEYE = 19 + +SCORE_THRES = 0.90 + + +def detect_faces(detector, image, min_size): + preds = detector(image) # bgr + # print(len(preds)) + + faces = [] + for pred in preds: + bb = pred['bbox'] + score = bb[-1] + if score < SCORE_THRES: + continue + + left, top, right, bottom = bb[:4] + cx = int((left + right) / 2) + cy = int((top + bottom) / 2) + fw = int(right - left) + fh = int(bottom - top) + + lex, ley = pred['keypoints'][KP_LEYE, 0:2] + rex, rey = pred['keypoints'][KP_REYE, 0:2] + angle = math.atan2(ley - rey, lex - rex) + angle = angle / math.pi * 180 + + faces.append((cx, cy, fw, fh, angle)) + + faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順 + return faces + + +def rotate_image(image, angle, cx, cy): + h, w = image.shape[0:2] + rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0) + + # # 回転する分、すこし画像サイズを大きくする→とりあえず無効化 + # nh = max(h, int(w * math.sin(angle))) + # nw = max(w, int(h * math.sin(angle))) + # if nh > h or nw > w: + # pad_y = nh - h + # pad_t = pad_y // 2 + # pad_x = nw - w + # pad_l = pad_x // 2 + # m = np.array([[0, 0, pad_l], + # [0, 0, pad_t]]) + # rot_mat = rot_mat + m + # h, w = nh, nw + # cx += pad_l + # cy += pad_t + + result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT) + return result, cx, cy + + +def process(args): + assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません" + assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません" + + # アニメ顔検出モデルを読み込む + print("loading face detector.") + detector = create_detector('yolov3') + + # cropの引数を解析する + if args.crop_size is None: + crop_width = crop_height = None + else: + tokens = args.crop_size.split(',') + assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください" + crop_width, crop_height = [int(t) for t in tokens] + + if args.crop_ratio is None: + crop_h_ratio = crop_v_ratio = None + else: + tokens = args.crop_ratio.split(',') + assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください" + crop_h_ratio, crop_v_ratio = [float(t) for t in tokens] + + # 画像を処理する + print("processing.") + output_extension = ".png" + + os.makedirs(args.dst_dir, exist_ok=True) + paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \ + glob.glob(os.path.join(args.src_dir, "*.webp")) + for path in tqdm(paths): + basename = os.path.splitext(os.path.basename(path))[0] + + # image = cv2.imread(path) # 日本語ファイル名でエラーになる + image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED) + if len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + if image.shape[2] == 4: + print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}") + image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい + + h, w = image.shape[:2] + + faces = detect_faces(detector, image, args.multiple_faces) + for i, face in enumerate(faces): + cx, cy, fw, fh, angle = face + face_size = max(fw, fh) + if args.min_size is not None and face_size < args.min_size: + continue + if args.max_size is not None and face_size >= args.max_size: + continue + face_suffix = f"_{i+1:02d}" if args.multiple_faces else "" + + # オプション指定があれば回転する + face_img = image + if args.rotate: + face_img, cx, cy = rotate_image(face_img, angle, cx, cy) + + # オプション指定があれば顔を中心に切り出す + if crop_width is not None or crop_h_ratio is not None: + cur_crop_width, cur_crop_height = crop_width, crop_height + if crop_h_ratio is not None: + cur_crop_width = int(face_size * crop_h_ratio + .5) + cur_crop_height = int(face_size * crop_v_ratio + .5) + + # リサイズを必要なら行う + scale = 1.0 + if args.resize_face_size is not None: + # 顔サイズを基準にリサイズする + scale = args.resize_face_size / face_size + if scale < cur_crop_width / w: + print( + f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") + scale = cur_crop_width / w + if scale < cur_crop_height / h: + print( + f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}") + scale = cur_crop_height / h + elif crop_h_ratio is not None: + # 倍率指定の時にはリサイズしない + pass + else: + # 切り出しサイズ指定あり + if w < cur_crop_width: + print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}") + scale = cur_crop_width / w + if h < cur_crop_height: + print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}") + scale = cur_crop_height / h + if args.resize_fit: + scale = max(cur_crop_width / w, cur_crop_height / h) + + if scale != 1.0: + w = int(w * scale + .5) + h = int(h * scale + .5) + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + cx = int(cx * scale + .5) + cy = int(cy * scale + .5) + fw = int(fw * scale + .5) + fh = int(fh * scale + .5) + + cur_crop_width = min(cur_crop_width, face_img.shape[1]) + cur_crop_height = min(cur_crop_height, face_img.shape[0]) + + x = cx - cur_crop_width // 2 + cx = cur_crop_width // 2 + if x < 0: + cx = cx + x + x = 0 + elif x + cur_crop_width > w: + cx = cx + (x + cur_crop_width - w) + x = w - cur_crop_width + face_img = face_img[:, x:x+cur_crop_width] + + y = cy - cur_crop_height // 2 + cy = cur_crop_height // 2 + if y < 0: + cy = cy + y + y = 0 + elif y + cur_crop_height > h: + cy = cy + (y + cur_crop_height - h) + y = h - cur_crop_height + face_img = face_img[y:y + cur_crop_height] + + # # debug + # print(path, cx, cy, angle) + # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8)) + # cv2.imshow("image", crp) + # if cv2.waitKey() == 27: + # break + # cv2.destroyAllWindows() + + # debug + if args.debug: + cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20) + + _, buf = cv2.imencode(output_extension, face_img) + with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f: + buf.tofile(f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ") + parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ") + parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する") + parser.add_argument("--resize_fit", action="store_true", + help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする") + parser.add_argument("--resize_face_size", type=int, default=None, + help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする") + parser.add_argument("--crop_size", type=str, default=None, + help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す") + parser.add_argument("--crop_ratio", type=str, default=None, + help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す") + parser.add_argument("--min_size", type=int, default=None, + help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)") + parser.add_argument("--max_size", type=int, default=None, + help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)") + parser.add_argument("--multiple_faces", action="store_true", + help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す") + parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します") + args = parser.parse_args() + + process(args) diff --git a/script/merge_block_weighted.py b/tools/merge_block_weighted.py similarity index 100% rename from script/merge_block_weighted.py rename to tools/merge_block_weighted.py diff --git a/script/merge_vae.py b/tools/merge_vae.py similarity index 100% rename from script/merge_vae.py rename to tools/merge_vae.py diff --git a/train_db_fixed/train_db_fixed.py b/train_db.py similarity index 99% rename from train_db_fixed/train_db_fixed.py rename to train_db.py index ce40aa4e..1dde882c 100644 --- a/train_db_fixed/train_db_fixed.py +++ b/train_db.py @@ -9,7 +9,7 @@ # v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization # add lr scheduler options, change handling folder/file caption, support loading DiffUser model from Huggingface # support save_ever_n_epochs/save_state in DiffUsers model -# fix the issue that prior_loss_weight is applyed to train images +# fix the issue that prior_loss_weight is applied to train images # v12: stop train text encode, tqdm smoothing # v13: bug fix # v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae @@ -43,7 +43,7 @@ from einops import rearrange from torch import einsum -import model_util +import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う TOKENIZER_PATH = "openai/clip-vit-large-patch14" @@ -986,7 +986,7 @@ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): # v12で更新:clip_sample=Falseに # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ + # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) @@ -1011,6 +1011,7 @@ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): if stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") text_encoder.train(False) + text_encoder.requires_grad_(False) with accelerator.accumulate(unet): with torch.no_grad(): @@ -1156,7 +1157,7 @@ def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): parser.add_argument("--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") diff --git a/train_db_fixed/model_util.py b/train_db_fixed/model_util.py deleted file mode 100644 index f3453025..00000000 --- a/train_db_fixed/model_util.py +++ /dev/null @@ -1,1182 +0,0 @@ -# v1: split from train_db_fixed.py. -# v2: support safetensors - -import math -import os -import torch -from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel -from safetensors.torch import load_file, save_file - -# DiffUsers版StableDiffusionのモデルパラメータ -NUM_TRAIN_TIMESTEPS = 1000 -BETA_START = 0.00085 -BETA_END = 0.0120 - -UNET_PARAMS_MODEL_CHANNELS = 320 -UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] -UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] -UNET_PARAMS_IMAGE_SIZE = 32 # unused -UNET_PARAMS_IN_CHANNELS = 4 -UNET_PARAMS_OUT_CHANNELS = 4 -UNET_PARAMS_NUM_RES_BLOCKS = 2 -UNET_PARAMS_CONTEXT_DIM = 768 -UNET_PARAMS_NUM_HEADS = 8 - -VAE_PARAMS_Z_CHANNELS = 4 -VAE_PARAMS_RESOLUTION = 256 -VAE_PARAMS_IN_CHANNELS = 3 -VAE_PARAMS_OUT_CH = 3 -VAE_PARAMS_CH = 128 -VAE_PARAMS_CH_MULT = [1, 2, 4, 4] -VAE_PARAMS_NUM_RES_BLOCKS = 2 - -# V2 -V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] -V2_UNET_PARAMS_CONTEXT_DIM = 1024 - -# Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" -DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" - - -# region StableDiffusion->Diffusersの変換コード -# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0) - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) - - -def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - unet_key = "model.diffusion_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params - - block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) - - return config - - -def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config - - -def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] - return text_model_dict - - -def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith("cond_stage_model"): - return None - - # common conversion - key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") - key = key.replace("cond_stage_model.model.", "text_model.") - - if "resblocks" in key: - # resblocks conversion - key = key.replace(".resblocks.", ".layers.") - if ".ln_" in key: - key = key.replace(".ln_", ".layer_norm") - elif ".mlp." in key: - key = key.replace(".c_fc.", ".fc1.") - key = key.replace(".c_proj.", ".fc2.") - elif '.attn.out_proj' in key: - key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in SD: {key}") - elif '.positional_embedding' in key: - key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") - elif '.ln_final' in key: - key = key.replace(".ln_final", ".final_layer_norm") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) - - key_suffix = ".weight" if "weight" in key else ".bias" - key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") - key_pfx = key_pfx.replace("_weight", "") - key_pfx = key_pfx.replace("_bias", "") - key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") - new_sd[key_pfx + "q_proj" + key_suffix] = values[0] - new_sd[key_pfx + "k_proj" + key_suffix] = values[1] - new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) - return new_sd - -# endregion - - -# region Diffusers->StableDiffusion の変換コード -# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - -def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ["proj_in.weight", "proj_out.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - - -def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("time_embed.0.weight", "time_embedding.linear_1.weight"), - ("time_embed.0.bias", "time_embedding.linear_1.bias"), - ("time_embed.2.weight", "time_embedding.linear_2.weight"), - ("time_embed.2.bias", "time_embedding.linear_2.bias"), - ("input_blocks.0.0.weight", "conv_in.weight"), - ("input_blocks.0.0.bias", "conv_in.bias"), - ("out.0.weight", "conv_norm_out.weight"), - ("out.0.bias", "conv_norm_out.bias"), - ("out.2.weight", "conv_out.weight"), - ("out.2.bias", "conv_out.bias"), - ] - - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ("in_layers.0", "norm1"), - ("in_layers.2", "conv1"), - ("out_layers.0", "norm2"), - ("out_layers.3", "conv2"), - ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut"), - ] - - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." - sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." - unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." - sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." - unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." - sd_up_res_prefix = f"output_blocks.{3*i + j}.0." - unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." - sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." - unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." - sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." - unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." - unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) - - hf_mid_atn_prefix = "mid_block.attentions.0." - sd_mid_atn_prefix = "middle_block.1." - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) - - for j in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{j}." - sd_mid_res_prefix = f"middle_block.{2*j}." - unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if "resnets" in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} - - if v2: - conv_transformer_to_linear(new_state_dict) - - return new_state_dict - - -# ================# -# VAE Conversion # -# ================# - -def reshape_weight_for_sd(w): - # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) - - -def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ("nin_shortcut", "conv_shortcut"), - ("norm_out", "conv_norm_out"), - ("mid.attn_1.", "mid_block.attentions.0."), - ] - - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." - sd_down_prefix = f"encoder.down.{i}.block.{j}." - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - - if i < 3: - hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." - sd_downsample_prefix = f"down.{i}.downsample." - vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - - hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." - sd_upsample_prefix = f"up.{3-i}.upsample." - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." - sd_up_prefix = f"decoder.up.{3-i}.block.{j}." - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f"mid_block.resnets.{i}." - sd_mid_res_prefix = f"mid.block_{i+1}." - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ("norm.", "group_norm."), - ("q.", "query."), - ("k.", "key."), - ("v.", "value."), - ("proj_out.", "proj_attn."), - ] - - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if "attentions" in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ["q", "k", "v", "proj_out"] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f"mid.attn_1.{weight_name}.weight" in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) - - return new_state_dict - - -# endregion - -# region 自作のモデル読み書きなど - -def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' - - -def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), - ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), - ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') - ] - - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, "cpu") - else: - checkpoint = torch.load(ckpt_path, map_location="cpu") - if "state_dict" in checkpoint: - state_dict = checkpoint["state_dict"] - else: - state_dict = checkpoint - checkpoint = None - - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - for key in state_dict.keys(): - if key.startswith(rep_from): - new_key = rep_to + key[len(rep_from):] - key_reps.append((key, new_key)) - - for key, new_key in key_reps: - state_dict[new_key] = state_dict[key] - del state_dict[key] - - return checkpoint, state_dict - - -# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 -def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if dtype is not None: - for k, v in state_dict.items(): - if type(v) is torch.Tensor: - state_dict[k] = v.to(dtype) - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(v2) - converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) - - unet = UNet2DConditionModel(**unet_config) - info = unet.load_state_dict(converted_unet_checkpoint) - print("loading u-net:", info) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) - - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=512, - torch_dtype="float32", - transformers_version="4.25.0.dev0", - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print("loading text encoder:", info) - - return text_model, vae, unet - - -def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): - def convert_key(key): - # position_idsの除去 - if ".position_ids" in key: - return None - - # common - key = key.replace("text_model.encoder.", "transformer.") - key = key.replace("text_model.", "") - if "layers" in key: - # resblocks conversion - key = key.replace(".layers.", ".resblocks.") - if ".layer_norm" in key: - key = key.replace(".layer_norm", ".ln_") - elif ".mlp." in key: - key = key.replace(".fc1.", ".c_fc.") - key = key.replace(".fc2.", ".c_proj.") - elif '.self_attn.out_proj' in key: - key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f"unexpected key in DiffUsers model: {key}") - elif '.position_embedding' in key: - key = key.replace("embeddings.position_embedding.weight", "positional_embedding") - elif '.token_embedding' in key: - key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") - elif 'final_layer_norm' in key: - key = key.replace("final_layer_norm", "ln_final") - return key - - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] - - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace("q_proj", "k_proj") - key_v = key.replace("q_proj", "v_proj") - - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") - new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print("make dummy weights for resblock.23, text_projection and logit scale.") - keys = list(new_sd.keys()) - for key in keys: - if key.startswith("transformer.resblocks.22."): - new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる - - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) - new_sd['logit_scale'] = torch.tensor(1) - - return new_sd - - -def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if "state_dict" in state_dict: - del state_dict["state_dict"] - else: - # 新しく作る - assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" - checkpoint = {} - state_dict = {} - strict = False - - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert not strict or key in state_dict, f"Illegal key in save SD: {key}" - if save_dtype is not None: - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd("model.diffusion_model.", unet_state_dict) - - # Convert the text encoder model - if v2: - make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) - update_sd("cond_stage_model.model.", text_enc_dict) - else: - text_enc_dict = text_encoder.state_dict() - update_sd("cond_stage_model.transformer.", text_enc_dict) - - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd("first_stage_model.", vae_dict) - - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count - - -def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") - if vae is None: - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = "first_stage_model." - - -def load_vae(vae_id, dtype): - print(f"load VAE: {vae_id}") - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) - except EnvironmentError as e: - print(f"exception occurs in loading vae: {e}") - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith(".bin"): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location="cpu") - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location="cpu") - vae_sd = vae_model['state_dict'] - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - return vae - - -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - -# endregion - - -def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) - - resos = set() - - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) - - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) - - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) - - size += divisible - - resos = list(resos) - resos.sort() - - aspect_ratios = [w / h for w, h in resos] - return resos, aspect_ratios - - -if __name__ == '__main__': - resos, aspect_ratios = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - print(aspect_ratios) - - ars = set() - for ar in aspect_ratios: - if ar in ars: - print("error! duplicate ar:", ar) - ars.add(ar) diff --git a/train_network.py b/train_network.py new file mode 100644 index 00000000..f26ced8b --- /dev/null +++ b/train_network.py @@ -0,0 +1,1453 @@ +import gc +import importlib +import json +import time +from typing import NamedTuple +from torch.autograd.function import Function +import argparse +import glob +import math +import os +import random + +from tqdm import tqdm +import torch +from torchvision import transforms +from accelerate import Accelerator +from accelerate.utils import set_seed +from transformers import CLIPTokenizer +import diffusers +from diffusers import DDPMScheduler, StableDiffusionPipeline +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum + +import library.model_util as model_util + +# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + +# checkpointファイル名 +EPOCH_STATE_NAME = "epoch-{:06d}-state" +LAST_STATE_NAME = "last-state" + +EPOCH_FILE_NAME = "epoch-{:06d}" +LAST_FILE_NAME = "last" + + +# region dataset + +class ImageInfo(): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: tuple[int, int] = None + self.bucket_reso: tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + batch_index: int + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: + super().__init__() + self.tokenizer: CLIPTokenizer = tokenizer + self.max_token_length = max_token_length + self.shuffle_caption = shuffle_caption + self.shuffle_keep_tokens = shuffle_keep_tokens + self.width, self.height = resolution + self.face_crop_aug_range = face_crop_aug_range + self.flip_aug = flip_aug + self.color_aug = color_aug + self.debug_dataset = debug_dataset + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + # augmentation + flip_p = 0.5 if flip_aug else 0.0 + if color_aug: + # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る + self.aug = albu.Compose([ + albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33), + albu.HorizontalFlip(p=flip_p) + ], p=1.) + elif flip_aug: + self.aug = albu.Compose([ + albu.HorizontalFlip(p=flip_p) + ], p=1.) + else: + self.aug = None + + self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) + + self.image_data: dict[str, ImageInfo] = {} + + def process_caption(self, caption): + if self.shuffle_caption: + tokens = caption.strip().split(",") + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + else: + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = keep_tokens + tokens + caption = ",".join(tokens).strip() + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo): + self.image_data[info.image_key] = info + + def make_buckets(self, enable_bucket, min_size, max_size): + ''' + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + ''' + + self.enable_bucket = enable_bucket + + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + # bucketingを用意する + if enable_bucket: + bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) + else: + # bucketはひとつだけ、すべての画像は同じ解像度 + bucket_resos = [(self.width, self.height)] + bucket_aspect_ratios = [self.width / self.height] + bucket_aspect_ratios = np.array(bucket_aspect_ratios) + + # bucketを作成する + if enable_bucket: + img_ar_errors = [] + for image_info in self.image_data.values(): + # bucketを決める + image_width, image_height = image_info.image_size + aspect_ratio = image_width / image_height + ar_errors = bucket_aspect_ratios - aspect_ratio + + bucket_id = np.abs(ar_errors).argmin() + image_info.bucket_reso = bucket_resos[bucket_id] + + ar_error = ar_errors[bucket_id] + img_ar_errors.append(ar_error) + else: + reso = (self.width, self.height) + for image_info in self.image_data.values(): + image_info.bucket_reso = reso + + # 画像をbucketに分割する + self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] + reso_to_index = {} + for i, reso in enumerate(bucket_resos): + reso_to_index[reso] = i + + for image_info in self.image_data.values(): + bucket_index = reso_to_index[image_info.bucket_reso] + for _ in range(image_info.num_repeats): + self.buckets[bucket_index].append(image_info.image_key) + + if enable_bucket: + print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") + for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + + # 参照用indexを作る + self.buckets_indices: list(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def resize_and_trim(self, image, reso): + image_height, image_width = image.shape[0:2] + ar_img = image_width / image_height + ar_reso = reso[0] / reso[1] + if ar_img > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + elif resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ + f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def cache_latents(self, vae): + print("caching latents.") + for info in tqdm(self.image_data.values()): + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.resize_and_trim(image, info.bucket_reso) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if self.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if self.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + .5) + nw = int(width * scale + .5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + .5) + face_cy = int(face_cy * scale + .5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if self.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1:p1 + target_size, :] + else: + image = image[:, p1:p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + return np.load(npz_file)['arr_0'] + + def __len__(self): + return self._length + + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() + + bucket = self.buckets[self.buckets_indices[index].bucket_index] + image_index = self.buckets_indices[index].batch_index * self.batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index:image_index + self.batch_size]: + image_info = self.image_data[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.resize_and_trim(img, image_info.bucket_reso) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + if self.aug is not None: + img = self.aug(image=img)['image'] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(image_info.caption) + captions.append(caption) + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) + example['input_ids'] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.random_crop = random_crop + self.latents_cache = None + self.enable_bucket = False + + def read_caption(img_path): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: + lines = f.readlines() + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(dir): + if not os.path.isdir(dir): + # print(f"ignore file: {dir}") + return 0, [], [] + + tokens = os.path.basename(dir).split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, [], [] + + caption_by_folder = '_'.join(tokens[1:]) + img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ + glob.glob(os.path.join(dir, "*.webp")) + print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path) + captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + + return n_repeats, img_paths, captions + + print("prepare train images.") + train_dirs = os.listdir(train_data_dir) + num_train_images = 0 + for dir in train_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) + num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, False, img_path) + self.register_image(info) + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + # reg imageは数を数えて学習画像と同じ枚数にする + num_reg_images = 0 + if reg_data_dir: + print("prepare reg images.") + reg_infos: list[ImageInfo] = [] + + reg_dirs = os.listdir(reg_data_dir) + for dir in reg_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) + num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, True, img_path) + reg_infos.append(info) + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + n = 0 + while n < num_train_images: + for info in reg_infos: + self.register_image(info) + n += info.num_repeats + if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない + break + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__(self, metadata, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) + + self.metadata = metadata + self.train_data_dir = train_data_dir + self.batch_size = batch_size + + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + + glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) + assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" + abs_path = abs_path[0] + + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') + + if not self.color_aug: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) + + self.register_image(image_info) + self.num_train_images = len(metadata) * dataset_repeats + self.num_reg_images = 0 + + # check existence of all npz files + if not self.color_aug: + npz_any = False + npz_all = True + for image_info in self.image_data.values(): + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if self.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") + elif not npz_all: + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + # check min/max bucket size + sizes = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + + if sizes is None: + self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated + else: + self.min_bucket_reso = min(sizes) + self.max_bucket_reso = max(sizes) + + def image_key_to_npz_file(self, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + '.npz' + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + '_flip.npz' + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # image_key is relative path + npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') + npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + +# endregion + + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(Function): + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = (q.shape[-1] ** -0.5) + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) + + p = exp_attn_weights / lc + + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, 'b h n d -> b n (h d)') + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format + # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format + del q_in, k_in, v_in + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers +# endregion + + +def collate_fn(examples): + return examples[0] + + +def train(args): + cache_latents = args.cache_latents + + # latentsをキャッシュする場合のオプション設定を確認する + if cache_latents: + assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" + + # その他のオプション設定を確認する + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + use_dreambooth_method = args.in_json is None + + # モデル形式のオプション設定を確認する: + load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + + # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) + + # tokenizerを読み込む + print("prepare tokenizer") + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + + if args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + # 学習データを用意する + assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + resolution = tuple([int(r) for r in args.resolution.split(',')]) + if len(resolution) == 1: + resolution = (resolution[0], resolution[0]) + assert len(resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.face_crop_aug_range is not None: + face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) + assert len( + face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + face_crop_aug_range = None + + # データセットを準備する + if use_dreambooth_method: + print("Use DreamBooth method.") + train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, + tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, + resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset) + else: + print("Train with captions.") + + # メタデータを読み込む + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + print(f"no metadata / メタデータファイルがありません: {args.in_json}") + return + + if args.color_aug: + print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") + + train_dataset = FineTuningDataset(metadata, args.train_batch_size, args.train_data_dir, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset) + + if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso): + print(f"using bucket info in metadata / メタデータ内のbucket情報を使います") + args.min_bucket_reso = train_dataset.min_bucket_reso + args.max_bucket_reso = train_dataset.max_bucket_reso + args.enable_bucket = True + print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}") + + if args.enable_bucket: + assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso) + + if args.debug_dataset: + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") + k = 0 + for example in train_dataset: + if example['latents'] is not None: + print("sample has latents from npz file") + for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') + if example['images'] is not None: + im = example['images'][j] + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or example['images'] is None: + break + return + + if len(train_dataset) == 0: + print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") + return + + # acceleratorを準備する + print("prepare accelerator") + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + # モデルを読み込む + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) + else: + print("load Diffusers pretrained models") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + # モデルに xformers とか memory efficient attention を組み込む + replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # prepare network + print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + net_kwargs = {} + if args.network_args is not None: + for net_arg in args.network_args: + key, value = net_arg.split('=') + net_kwargs[key] = value + + network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) + if network is None: + return + + if args.network_weights is not None: + print("load network weights from:", args.network_weights) + network.load_weights(args.network_weights) + + train_unet = not args.network_train_text_encoder_only + train_text_encoder = not args.network_train_unet_only + network.apply_to(text_encoder, unet, train_text_encoder, train_unet) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") + + # 8-bit Adamを使う + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print("use 8-bit Adam optimizer") + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + + # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 + optimizer = optimizer_class(trainable_params, lr=args.learning_rate) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) + + # lr schedulerを用意する + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) + + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする + if args.full_fp16: + assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") + # unet.to(weight_dtype) + # text_encoder.to(weight_dtype) + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + if train_unet and train_text_encoder: + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler) + elif train_unet: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler) + elif train_text_encoder: + text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, network, optimizer, train_dataloader, lr_scheduler) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler) + + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.eval() + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.eval() + + network.prepare_grad_etc(text_encoder, unet) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") + print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", + num_train_timesteps=1000, clip_sample=False) + + if accelerator.is_main_process: + accelerator.init_trackers("network_train") + + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + + # 指定したステップ数までText Encoderを学習する:epoch最初の状態 + network.on_epoch_start(text_encoder, unet) + + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(network): + with torch.no_grad(): + # latentに変換 + if batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要 + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + current_loss = loss.detach().item() + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step+1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"epoch_loss": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch+1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: + print("saving checkpoint.") + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) + unwrap_model(network).save_weights(ckpt_file, save_dtype) + + if args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + + is_main_process = accelerator.is_main_process + if is_main_process: + network = unwrap_model(network) + + accelerator.end_training() + + if args.save_state: + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) + print(f"save trained model to {ckpt_file}") + network.save_weights(ckpt_file, save_dtype) + print("model saved.") + + +if __name__ == '__main__': + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + parser.add_argument("--network_weights", type=str, default=None, + help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--keep_tokens", type=int, default=None, + help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--dataset_repeats", type=int, default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument("--face_crop_aug_range", type=str, default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") + parser.add_argument("--random_crop", action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--resolution", type=str, default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--vae", type=str, default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + parser.add_argument("--cache_latents", action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") + # parser.add_argument("--stop_text_encoder_training", type=int, default=None, + # help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') + parser.add_argument("--network_dim", type=int, default=None, + help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_args", type=str, default=None, nargs='*', + help='additional argmuments for network (key=value) / ネットワークへの追加の引数') + parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") + parser.add_argument("--network_train_text_encoder_only", action="store_true", + help="only training Text Encoder part / Text Encoder関連部分のみ学習する") + + args = parser.parse_args() + train(args)