Skip to content

Commit

Permalink
fix to work
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Nov 29, 2024
1 parent 2238b94 commit 744cf03
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __next__(self):
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")

new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
Expand Down Expand Up @@ -166,6 +167,10 @@ def collate_fn_remove_corrupted(batch):


def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"

# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
Expand Down Expand Up @@ -436,7 +441,7 @@ def run_batch(
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": [image_size.width, image_size.height]}
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
Expand Down Expand Up @@ -464,6 +469,7 @@ def run_batch(

# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > 1 or (major == 1 and minor > 0):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
Expand All @@ -480,7 +486,7 @@ def run_batch(
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = ArchiveImageLoader(image_paths, args.batch_size)
loader = ArchiveImageLoader([str(p) for p in image_paths], args.batch_size)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション
Expand Down

0 comments on commit 744cf03

Please sign in to comment.