Merge pull request #1745 from bmaltais/dev2
v22.3.0
bmaltais authored Dec 6, 2023
2 parents 8fb0b31 + 06eed69 commit b3ea59c
Showing 24 changed files with 850 additions and 225 deletions.
2 changes: 1 addition & 1 deletion .release
@@ -1 +1 @@
-v22.2.2
+v22.3.0
14 changes: 14 additions & 0 deletions README.md
@@ -651,6 +651,20 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b


## Change History
* 2023/12/06 (v22.3.0)
- Merge sd-scripts updates:
  - `finetune\tag_images_by_wd14_tagger.py` now supports separators other than `,` via the `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
  - Min SNR Gamma with V-prediction (SD 2.1) is fixed. Thanks to feffy380! PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934)
    - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
  - `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
    - The default values are the same as in the previous version.
  - Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
    - The `--ds_timesteps_1` and `--ds_timesteps_2` options set the timesteps until which Deep Shrink is applied in the first and second stages.
    - The `--ds_depth_1` and `--ds_depth_2` options set the depth (block index) of Deep Shrink for the first and second stages.
    - The `--ds_ratio` option sets the Deep Shrink downsampling ratio; `0.5` means half of the original latent size.
    - The `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available; see the usage sketch after this change-history excerpt.
  - Added GLoRA support.

* 2023/12/03 (v22.2.2)
- Update Lycoris module to 2.0.0 (https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/README.md)
- Update Lycoris merge and extract tools
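
As a rough usage sketch of the two user-facing v22.3.0 additions (not from the repository: the image path, checkpoint name, and prompt are illustrative, and real invocations need further options such as an output directory):

```python
# Hedged usage sketch; paths, checkpoint, and prompt text are illustrative.
import subprocess

# Tag images with a custom separator instead of the default ", ":
subprocess.run([
    "python", "finetune/tag_images_by_wd14_tagger.py",
    "path/to/images",
    "--caption_separator", "; ",
], check=True)

# Enable Deep Shrink from the gen_img_diffusers.py command line:
subprocess.run([
    "python", "gen_img_diffusers.py",
    "--ckpt", "model.safetensors",  # checkpoint path illustrative
    "--ds_depth_1", "3", "--ds_timesteps_1", "650", "--ds_ratio", "0.5",
    "--prompt", "a castle on a hill",
], check=True)
```

The same settings can also be overridden per prompt with `--dsd1`, `--dst1`, `--dsd2`, `--dst2`, and `--dsr` inside the prompt text, as the parsing hunk in `gen_img_diffusers.py` below shows.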
2 changes: 1 addition & 1 deletion fine_tune.py
@@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
-loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
24 changes: 16 additions & 8 deletions finetune/tag_images_by_wd14_tagger.py
@@ -160,7 +160,9 @@ def main(args):

tag_freq = {}

-undesired_tags = set(args.undesired_tags.split(","))
+caption_separator = args.caption_separator
+stripped_caption_separator = caption_separator.strip()
+undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))

def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
@@ -194,7 +196,7 @@ def run_batch(path_imgs):

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
-general_tag_text += ", " + tag_name
+general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
@@ -203,18 +205,18 @@

if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
-character_tag_text += ", " + tag_name
+character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)

# strip the leading separator
if len(general_tag_text) > 0:
-general_tag_text = general_tag_text[2:]
+general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
-character_tag_text = character_tag_text[2:]
+character_tag_text = character_tag_text[len(caption_separator) :]

caption_file = os.path.splitext(image_path)[0] + args.caption_extension

tag_text = ", ".join(combined_tags)
tag_text = caption_separator.join(combined_tags)

if args.append_tags:
# Check if file exists
@@ -224,13 +226,13 @@ def run_batch(path_imgs):
existing_content = f.read().strip("\n") # Remove newlines

# Split the content into tags and store them in a list
-existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
+existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]

# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]

# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
tag_text = caption_separator.join(existing_tags + new_tags)

with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
@@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)

return parser
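
A minimal standalone sketch of how the new separator flows through tag assembly above (tag names are made up; the real code takes them from the model output):

```python
# Standalone sketch of the separator handling above; tag names are made up.
caption_separator = ", "  # argparse default
stripped_caption_separator = caption_separator.strip()  # used for splitting

combined_tags = ["1girl", "solo", "smile"]
tag_text = caption_separator.join(combined_tags)
print(tag_text)  # 1girl, solo, smile

# --append_tags path: split the existing caption on the stripped separator,
# then append only the tags that are not already present.
existing_content = "1girl, hat"
existing_tags = [t.strip() for t in existing_content.split(stripped_caption_separator) if t.strip()]
new_tags = [t for t in combined_tags if t not in existing_tags]
print(caption_separator.join(existing_tags + new_tags))  # 1girl, hat, solo, smile
```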

88 changes: 86 additions & 2 deletions gen_img_diffusers.py
@@ -105,7 +105,7 @@
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
-from library.original_unet import UNet2DConditionModel
+from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction

from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -378,7 +378,7 @@ def __init__(
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
-unet: UNet2DConditionModel,
+unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int,
clip_model: CLIPModel,
@@ -2196,6 +2196,7 @@ def main(args):
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)

# load the VAE
if args.vae is not None:
@@ -2352,13 +2353,20 @@ def __getattr__(self, item):
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
vae.eval()

text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)

text_encoder.eval()
unet.eval()

if clip_model is not None:
clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None:
vgg16_model.to(dtype).to(device)
vgg16_model.eval()

# incorporate the networks
if args.network_module:
@@ -2501,6 +2509,10 @@ def __getattr__(self, item):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()

# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)

# process Extended Textual Inversion and Textual Inversion
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
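
Taken together, the Deep Shrink wiring in this file amounts to the following minimal sketch (a restatement of the hunks above, not new API; `set_deep_shrink` is called positionally exactly as in the diff):

```python
# Sketch combining the UNet wrap in main() with the startup hook above.
unet = InferUNet2DConditionModel(original_unet)  # wrap the loaded UNet for inference

if args.ds_depth_1 is not None:
    # Shrink the inner latents to ds_ratio at the given block depths until
    # the corresponding timesteps are reached (defaults: 650 and 0.5).
    unet.set_deep_shrink(
        args.ds_depth_1, args.ds_timesteps_1,
        args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio,
    )
```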
@@ -3085,6 +3097,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
clip_prompt = None
network_muls = None

# Deep Shrink
ds_depth_1 = None  # None means: no override from the prompt
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio

prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3156,10 +3175,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
print(f"network mul: {network_muls}")
continue

# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue

m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue

m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue

m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue

m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue

except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)

# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)

# prepare seed
if seeds is not None: # given in prompt
# if not enough seeds are given, reuse the previous ones
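
A condensed, standalone sketch of the per-prompt Deep Shrink parsing above (only the `dsd1` and `dsr` branches are reproduced; the prompt text is illustrative):

```python
# Standalone sketch of the prompt-option parsing above (two branches only).
import re

raw_prompt = "a castle on a hill --dsd1 3 --dsr 0.5"
ds_depth_1 = None  # None means: do not override Deep Shrink
ds_ratio = 0.5

for parg in raw_prompt.strip().split(" --")[1:]:
    m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
    if m:  # deep shrink depth 1
        ds_depth_1 = int(m.group(1))
        continue
    m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
    if m:  # deep shrink ratio; implies override with the default depth
        ds_ratio = float(m.group(1))
        ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1
        continue

print(ds_depth_1, ds_ratio)  # 3 0.5
```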
@@ -3509,6 +3569,30 @@ def setup_parser() -> argparse.ArgumentParser:
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)

return parser


1 change: 1 addition & 0 deletions library/config_util.py
@@ -51,6 +51,7 @@ class BaseSubsetParams:
image_dir: Optional[str] = None
num_repeats: int = 1
shuffle_caption: bool = False
caption_separator: str = ","
keep_tokens: int = 0
color_aug: bool = False
flip_aug: bool = False
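
For context, a minimal sketch of the new field in place, assuming `BaseSubsetParams` is a `@dataclass` (its annotated fields with defaults suggest so); surrounding fields abridged:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BaseSubsetParams:  # abridged sketch; see library/config_util.py
    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    caption_separator: str = ","  # new field; a trailing comma here would
                                  # silently turn the default into (",",)
    keep_tokens: int = 0
```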
9 changes: 6 additions & 3 deletions library/custom_train_functions.py
@@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas):
noise_scheduler.alphas_cumprod = alphas_cumprod


-def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
-gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
-snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
+min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
+if v_prediction:
+    snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
+else:
+    snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss
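
A worked example of the two branches above (standalone; the SNR values are illustrative, and `gamma=5` is the setting recommended by the Min-SNR paper):

```python
# Standalone sketch of the Min-SNR weighting above; SNR values illustrative.
import torch

snr = torch.tensor([0.5, 5.0, 50.0])  # low-, boundary-, and high-SNR timesteps
gamma = 5.0
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))

eps_weight = torch.div(min_snr_gamma, snr)    # epsilon prediction: min(SNR, γ) / SNR
v_weight = torch.div(min_snr_gamma, snr + 1)  # v-prediction: min(SNR, γ) / (SNR + 1)

print(eps_weight)  # tensor([1.0000, 1.0000, 0.1000])
print(v_weight)    # ≈ tensor([0.3333, 0.8333, 0.0980])
```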
