Validation loss #1864

Open · wants to merge 78 commits into base: sd3
Changes from all commits
Commits (78)
5b19bda
Add validation loss
rockerBOO Nov 5, 2023
33c311e
new ratio code
rockerBOO Nov 5, 2023
3de9e6c
Add validation split of datasets
rockerBOO Nov 5, 2023
a93c524
Update args to validation_seed and validation_split
rockerBOO Nov 5, 2023
c892521
Add process_batch for train_network
rockerBOO Nov 5, 2023
e545fdf
Removed/cleanup a line
rockerBOO Nov 5, 2023
9c591bd
Remove unnecessary subset line from collate
rockerBOO Nov 5, 2023
569ca72
Set grad enabled if is_train and train_text_encoder
rockerBOO Nov 7, 2023
b558a5b
val
gesen2egee Mar 9, 2024
78cfb01
improve
gesen2egee Mar 10, 2024
923b761
Update train_network.py
gesen2egee Mar 10, 2024
47359b8
Update train_network.py
gesen2egee Mar 10, 2024
a51723c
fix timesteps
gesen2egee Mar 11, 2024
7d84ac2
only use train subset to val
gesen2egee Mar 11, 2024
befbec5
Update train_network.py
gesen2egee Mar 11, 2024
63e58f7
Update train_network.py
gesen2egee Mar 11, 2024
a6c41c6
Update train_network.py
gesen2egee Mar 11, 2024
bd7e229
fix
gesen2egee Mar 13, 2024
5d7ed0d
Merge remote-tracking branch 'kohya-ss/dev' into val
gesen2egee Mar 13, 2024
d05965d
Update train_network.py
gesen2egee Mar 13, 2024
b5e8045
fix control net
gesen2egee Mar 16, 2024
086f600
Merge branch 'main' into val
gesen2egee Apr 10, 2024
36d4023
Update config_util.py
gesen2egee Apr 10, 2024
229c5a3
Update train_util.py
gesen2egee Apr 10, 2024
3b251b7
Update config_util.py
gesen2egee Apr 10, 2024
459b125
Update config_util.py
gesen2egee Apr 10, 2024
89ad69b
Update train_util.py
gesen2egee Apr 11, 2024
fde8026
Update config_util.py
gesen2egee Apr 11, 2024
31507b9
Remove unnecessary is_train changes and use apply_debiased_estimation…
gesen2egee Aug 2, 2024
1db4951
Update train_db.py
gesen2egee Aug 4, 2024
6816217
Update train_db.py
gesen2egee Aug 4, 2024
96eb74f
Update train_db.py
gesen2egee Aug 4, 2024
b9bdd10
Update train_network.py
gesen2egee Aug 4, 2024
3d68754
Update train_db.py
gesen2egee Aug 4, 2024
a593e83
Update train_network.py
gesen2egee Aug 4, 2024
f6dbf7c
Update train_network.py
gesen2egee Aug 4, 2024
aa850aa
Update train_network.py
gesen2egee Aug 4, 2024
cdb2d9c
Update train_network.py
gesen2egee Aug 4, 2024
3028027
Update train_network.py
gesen2egee Oct 4, 2024
dece2c3
Update train_db.py
gesen2egee Oct 4, 2024
05bb918
Add Validation loss for LoRA training
hinablue Dec 27, 2024
62164e5
Change val loss calculate method
hinablue Dec 27, 2024
64bd531
Split val latents/batch and pick up val latents shape size which equa…
hinablue Dec 28, 2024
cb89e02
Change val latent loss compare
hinablue Dec 28, 2024
8743532
val
gesen2egee Mar 9, 2024
449c1c5
Adding modified train_util and config_util
rockerBOO Jan 2, 2025
7f6e124
Merge branch 'gesen2egee/val' into validation-loss-upstream
rockerBOO Jan 3, 2025
d23c732
Merge remote-tracking branch 'hina/feature/val-loss' into validation-…
rockerBOO Jan 3, 2025
7470173
Remove defunct code for train_controlnet.py
rockerBOO Jan 3, 2025
534059d
Typos and lingering is_train
rockerBOO Jan 3, 2025
c8c3569
Cleanup order, types, print to logger
rockerBOO Jan 3, 2025
fbfc275
Update text for train/reg with repeats
rockerBOO Jan 3, 2025
58bfa36
Add seed help clarifying info
rockerBOO Jan 3, 2025
6604b36
Remove duplicate assignment
rockerBOO Jan 3, 2025
0522070
Fix training, validation split, revert to using upstream implemenation
rockerBOO Jan 3, 2025
695f389
Move get_huber_threshold_if_needed
rockerBOO Jan 3, 2025
1f9ba40
Add step break for validation epoch. Remove unused variable
rockerBOO Jan 3, 2025
1c0ae30
Add missing functions for training batch
rockerBOO Jan 3, 2025
bbf6bbd
Use self.get_noise_pred_and_target and drop fixed timesteps
rockerBOO Jan 6, 2025
f4840ef
Revert train_db.py
rockerBOO Jan 6, 2025
1c63e7c
Cleanup unused code and formatting
rockerBOO Jan 6, 2025
c64d1a2
Add validate_every_n_epochs, change name validate_every_n_steps
rockerBOO Jan 6, 2025
f885029
Fix validate epoch, cleanup imports
rockerBOO Jan 6, 2025
fcb2ff0
Clean up some validation help documentation
rockerBOO Jan 6, 2025
742bee9
Set validation steps in multiple lines for readability
rockerBOO Jan 6, 2025
1231f51
Remove unused train_util code, fix accelerate.log for wandb, add init…
rockerBOO Jan 8, 2025
556f3f1
Fix documentation, remove unused function, fix bucket reso for sd1.5,…
rockerBOO Jan 8, 2025
9fde0d7
Handle tuple return from generate_dataset_group_by_blueprint
rockerBOO Jan 8, 2025
1e61392
Revert bucket_reso_steps to correct 64
rockerBOO Jan 8, 2025
d6f158d
Fix incorrect destructoring for load_abritrary_dataset
rockerBOO Jan 8, 2025
264167f
Apply is_training_dataset only to DreamBoothDataset. Add validation_s…
rockerBOO Jan 9, 2025
4c61adc
Add divergence to logs
rockerBOO Jan 12, 2025
2bbb40c
Fix regularization images with validation
rockerBOO Jan 12, 2025
0456858
Fix validate_every_n_steps always running first step
rockerBOO Jan 12, 2025
ee9265c
Fix validate_every_n_steps for gradient accumulation
rockerBOO Jan 12, 2025
25929dd
Remove Validating... print to fix output layout
rockerBOO Jan 12, 2025
b489082
Disable repeats for validation datasets
rockerBOO Jan 12, 2025
c04e5df
Fix loss recorder on 0. Fix validation for cached runs. Assert on val…
rockerBOO Jan 23, 2025
Files changed
3 changes: 2 additions & 1 deletion fine_tune.py
@@ -91,9 +91,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
3 changes: 2 additions & 1 deletion flux_train.py
@@ -138,9 +138,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
3 changes: 2 additions & 1 deletion flux_train_control_net.py
@@ -126,9 +126,10 @@ def train(args):
}

blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None

current_epoch = Value("i", 0)
current_step = Value("i", 0)
15 changes: 10 additions & 5 deletions flux_train_network.py
@@ -2,7 +2,7 @@
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union

import torch
from accelerate import Accelerator
@@ -36,8 +36,8 @@ def __init__(self):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)

if args.fp8_base_unet:
@@ -80,6 +80,8 @@ def assert_extra_args(self, args, train_dataset_group):
args.blocks_to_swap = 18 # 18 is safe for most cases

train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this

def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
@@ -339,6 +341,7 @@ def get_noise_pred_and_target(
network,
weight_dtype,
train_unet,
is_train=True
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -375,7 +378,7 @@
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with accelerator.autocast():
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
@@ -420,7 +423,9 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)

with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""

return model_pred
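The `is_train` flag threaded through `get_noise_pred_and_target` is what lets the same prediction path serve both training and validation batches: gradients are recorded only when `is_train and train_unet` is true, so a validation batch costs a forward pass and never builds an autograd graph. A toy sketch of that gating pattern (the tiny linear model is purely illustrative, not the FLUX transformer):

```python
import torch
import torch.nn as nn

def predict(model: nn.Module, x: torch.Tensor, is_train: bool, train_unet: bool) -> torch.Tensor:
    # Mirrors the gating used in the diff: only track gradients when this is a
    # training step and the model is actually being trained.
    with torch.set_grad_enabled(is_train and train_unet):
        return model(x)

model = nn.Linear(4, 4)
x = torch.randn(2, 4)
print(predict(model, x, is_train=True, train_unet=True).requires_grad)   # True
print(predict(model, x, is_train=False, train_unet=True).requires_grad)  # False
```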
198 changes: 110 additions & 88 deletions library/config_util.py
@@ -73,6 +73,8 @@ class BaseSubsetParams:
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0


@dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0


@dataclass
@@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0



@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
batch_size: int = 1
@@ -234,6 +237,8 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"validation_seed": int,
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
}
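Because `validation_seed` and `validation_split` are registered here as ordinary dataset-level schema entries, they can be supplied through the same `user_config` structure that `blueprint_generator.generate(user_config, args)` consumes in the training scripts. A minimal sketch; the paths and the other keys are placeholders rather than part of this PR:

```python
# Hypothetical dataset config exercising the new options; only validation_seed
# and validation_split are introduced by this PR, everything else is a stand-in.
user_config = {
    "datasets": [
        {
            "resolution": 512,
            "batch_size": 2,
            "validation_seed": 42,     # seed used when carving out the validation subset
            "validation_split": 0.1,   # fraction of images reserved for validation loss
            "subsets": [
                {
                    "image_dir": "/path/to/train/images",  # placeholder path
                    "caption_extension": ".txt",
                    "num_repeats": 1,
                }
            ],
        }
    ]
}
```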
@@ -462,119 +467,136 @@ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):

return default_value


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

for dataset_blueprint in dataset_group_blueprint.datasets:
extra_dataset_params = {}

if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": True}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
continue

if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""
),
" ",
)
# if the dataset isn't setting a validation split, there is no current validation dataset
if dataset_blueprint.params.validation_split == 0.0:
continue

if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
extra_dataset_params = {}
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
# DreamBooth datasets support splitting training and validation datasets
extra_dataset_params = {"is_training_dataset": False}
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

logger.info(f"{info}")
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
val_datasets.append(dataset)

def print_info(_datasets, dataset_type: str):
info = ""
for i, dataset in enumerate(_datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
""")

if dataset.enable_bucket:
info += indent(dedent(f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""), " ")
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(dedent(f"""\
[Subset {j} of {dataset_type} {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""), " ")

if is_dreambooth:
info += indent(dedent(f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")

logger.info(info)

print_info(datasets, "Dataset")

if len(val_datasets) > 0:
print_info(val_datasets, "Validation Dataset")

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no

for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
logger.info(f"[Prepare dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Prepare validation dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
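The split itself happens inside the dataset classes in `train_util.py`, which are not part of this excerpt. Purely as an illustration of what `validation_seed` and `validation_split` imply, a deterministic split of this kind typically looks like the sketch below (the helper name and the item list are hypothetical):

```python
import random
from typing import List, Tuple, TypeVar

T = TypeVar("T")

def split_train_val(items: List[T], validation_split: float, validation_seed: int) -> Tuple[List[T], List[T]]:
    # Shuffle indices with a fixed seed so the validation subset is disjoint
    # from the training subset and reproducible across runs.
    indices = list(range(len(items)))
    random.Random(validation_seed).shuffle(indices)
    n_val = int(len(items) * validation_split)
    val_indices = set(indices[:n_val])
    train = [item for i, item in enumerate(items) if i not in val_indices]
    val = [item for i, item in enumerate(items) if i in val_indices]
    return train, val

train_imgs, val_imgs = split_train_val([f"img_{i}.png" for i in range(10)], 0.2, 47)
print(len(train_imgs), len(val_imgs))  # 8 2
```

A fixed seed keeps the held-out images stable between runs, so the reported validation loss stays comparable as training configurations change.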