Skip to content

Commit

Permalink
feat: merge with the latest sd-scripts update
Browse files Browse the repository at this point in the history
  • Loading branch information
Linaqruf committed Mar 20, 2023
1 parent cfd47f6 commit d40c4ee
Show file tree
Hide file tree
Showing 17 changed files with 6,115 additions and 3,945 deletions.
736 changes: 378 additions & 358 deletions fine_tune.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion finetune/merge_captions_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List
from tqdm import tqdm
import library.train_util as train_util

import os

def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
Expand All @@ -29,6 +29,9 @@ def main(args):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()

if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)

image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
Expand Down
5 changes: 4 additions & 1 deletion finetune/merge_dd_tags_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List
from tqdm import tqdm
import library.train_util as train_util

import os

def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
Expand All @@ -29,6 +29,9 @@ def main(args):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()

if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)

image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
Expand Down
52 changes: 21 additions & 31 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@
import random
import re

import toml

import diffusers
import numpy as np
import torch
Expand Down Expand Up @@ -1651,10 +1649,11 @@ def get_unweighted_text_embeddings(
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
if text_input_chunk[:, -1] != eos and text_input_chunk[:, -1] != pad: # 最後に普通の文字がある
text_input_chunk[:, -1] = eos
if text_input_chunk[:, 1] == pad: # BOSだけであとはPAD
text_input_chunk[:, 1] = eos
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos

if clip_skip is None or clip_skip == 1:
text_embedding = pipe.text_encoder(text_input_chunk)[0]
Expand Down Expand Up @@ -2278,13 +2277,26 @@ def resize_images(imgs, size):
mask_images = l

# 画像サイズにオプション指定があるときはリサイズする
if init_images is not None and args.W is not None and args.H is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
if args.W is not None and args.H is not None:
if init_images is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
if mask_images is not None:
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))

if networks and mask_images:
# mask を領域情報として流用する、現在は1枚だけ対応
# TODO 複数のnetwork classの混在時の考慮
print("use mask as region")
# import cv2
# for i in range(3):
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
# cv2.waitKey()
# cv2.destroyAllWindows()
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
mask_images = None

prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
Expand Down Expand Up @@ -2774,27 +2786,5 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')

parser.add_argument("--config_file", type=str, default=None, help="using .toml instead of args to pass hyperparameter")

args = parser.parse_args()

if args.config_file:
config_path = args.config_file + ".toml" if not args.config_file.endswith(".toml") else args.config_file
if os.path.exists(config_path):
print(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
config_dict = toml.load(f)

ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
for key, value in section_dict.items():
ignore_nesting_dict[key] = value

config_args = argparse.Namespace(**ignore_nesting_dict)
args = parser.parse_args(namespace=config_args)
args.config_file = args.config_file.split(".")[0]
print(args.config_file)
else:
print(f"{config_path} not found.")

main(args)
6 changes: 4 additions & 2 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int)
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
}

# options handled by argparse but not handled by user config
Expand Down Expand Up @@ -283,7 +283,9 @@ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) ->
def __merge_dict(*dict_list: dict) -> dict:
merged = {}
for schema in dict_list:
merged.update(schema)
# merged |= schema
for k, v in schema.items():
merged[k] = v
return merged


Expand Down
Loading

0 comments on commit d40c4ee

Please sign in to comment.