Add handler to sample and save on demand

fabbarix committed Dec 21, 2024
1 parent e896539 commit 04ed9ed
Showing 16 changed files with 101 additions and 66 deletions.
9 changes: 6 additions & 3 deletions fine_tune.py
@@ -12,6 +12,7 @@
import torch
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler

init_ipex()

@@ -344,6 +345,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
signal_handler = SignalHandler()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -426,12 +428,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
global_step += 1

train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet, force_sample=signal_handler.should_sample()
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
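The new library/signal_handler.py module that provides SignalHandler is among the 16 changed files but is not shown in the loaded portion of this diff. Going only by the interface used above (should_sample, should_save, reset_sample, reset_save), a minimal sketch of such a handler might look like the following; the POSIX signals SIGUSR1 (sample) and SIGUSR2 (save) are an assumption, not something this excerpt confirms:

# Hypothetical sketch of library/signal_handler.py; the real file is not shown in this excerpt.
# Assumes SIGUSR1 requests sample images and SIGUSR2 requests a checkpoint save.
import signal


class SignalHandler:
    def __init__(self):
        self._sample_requested = False
        self._save_requested = False
        # SIGUSR1/SIGUSR2 do not exist on Windows, so guard the registration.
        if hasattr(signal, "SIGUSR1"):
            signal.signal(signal.SIGUSR1, self._on_sample)
        if hasattr(signal, "SIGUSR2"):
            signal.signal(signal.SIGUSR2, self._on_save)

    def _on_sample(self, signum, frame):
        self._sample_requested = True

    def _on_save(self, signum, frame):
        self._save_requested = True

    def should_sample(self) -> bool:
        return self._sample_requested

    def should_save(self) -> bool:
        return self._save_requested

    def reset_sample(self):
        self._sample_requested = False

    def reset_save(self):
        self._save_requested = False

With a handler like this, a running trainer could be asked for sample images with kill -USR1 <pid> and for a checkpoint with kill -USR2 <pid>; the loops below poll the flags once per step and reset them after acting on them.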
9 changes: 6 additions & 3 deletions flux_train.py
@@ -26,6 +26,7 @@
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler

init_ipex()

@@ -579,6 +580,7 @@ def grad_hook(parameter: torch.Tensor):

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
signal_handler = SignalHandler()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -705,12 +707,13 @@ def grad_hook(parameter: torch.Tensor):

optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, force_sample=signal_handler.should_sample()
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
5 changes: 2 additions & 3 deletions flux_train_network.py
@@ -278,12 +278,11 @@ def cache_text_encoder_outputs_if_needed(
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux, force_sample=False):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs, force_sample=force_sample
)
# return

24 changes: 13 additions & 11 deletions library/flux_train_utils.py
@@ -40,21 +40,23 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
-    controlnet=None
+    controlnet=None,
+    force_sample=False
 ):
-    if steps == 0:
-        if not args.sample_at_first:
-            return
-    else:
-        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-            return
-        if args.sample_every_n_epochs is not None:
-            # sample_every_n_steps は無視する
-            if epoch is None or epoch % args.sample_every_n_epochs != 0:
-                return
-        else:
-            if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
-                return
+    if not force_sample:
+        if steps == 0:
+            if not args.sample_at_first:
+                return
+        else:
+            if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+                return
+            if args.sample_every_n_epochs is not None:
+                # sample_every_n_steps は無視する
+                if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                    return
+            else:
+                if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
+                    return

logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
22 changes: 12 additions & 10 deletions library/sd3_train_utils.py
@@ -381,20 +381,22 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
+    force_sample=False
 ):
-    if steps == 0:
-        if not args.sample_at_first:
-            return
-    else:
-        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-            return
-        if args.sample_every_n_epochs is not None:
-            # sample_every_n_steps は無視する
-            if epoch is None or epoch % args.sample_every_n_epochs != 0:
-                return
-        else:
-            if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
-                return
+    if not force_sample:
+        if steps == 0:
+            if not args.sample_at_first:
+                return
+        else:
+            if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+                return
+            if args.sample_every_n_epochs is not None:
+                # sample_every_n_steps は無視する
+                if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                    return
+            else:
+                if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
+                    return

logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
22 changes: 12 additions & 10 deletions library/train_util.py
@@ -6113,25 +6113,27 @@ def sample_images_common(
unet,
prompt_replacement=None,
controlnet=None,
+    force_sample=False
 ):
     """
     StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
     TODO Use strategies here
     """

-    if steps == 0:
-        if not args.sample_at_first:
-            return
-    else:
-        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-            return
-        if args.sample_every_n_epochs is not None:
-            # sample_every_n_steps は無視する
-            if epoch is None or epoch % args.sample_every_n_epochs != 0:
-                return
-        else:
-            if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
-                return
+    if not force_sample:
+        if steps == 0:
+            if not args.sample_at_first:
+                return
+        else:
+            if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+                return
+            if args.sample_every_n_epochs is not None:
+                # sample_every_n_steps は無視する
+                if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                    return
+            else:
+                if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
+                    return

logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
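The same gating change is applied in all three sampling helpers above (flux_train_utils.sample_images, sd3_train_utils.sample_images, and train_util.sample_images_common): the existing step/epoch interval checks are wrapped in "if not force_sample:", so an on-demand request bypasses every interval check. Condensed into a standalone predicate for readability (the helper name below is illustrative, not part of the commit), the new control flow is:

def sampling_is_skipped(args, steps, epoch, force_sample=False):
    # Paraphrase of the gating added in this commit: True means sample generation is skipped.
    if force_sample:
        return False  # an on-demand request always samples
    if steps == 0:
        return not args.sample_at_first
    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
        return True  # sampling is not configured at all
    if args.sample_every_n_epochs is not None:
        # epoch-based sampling takes precedence; sample_every_n_steps is ignored
        return epoch is None or epoch % args.sample_every_n_epochs != 0
    # step-based sampling: skip when steps is not a multiple, or at end of epoch
    return steps % args.sample_every_n_steps != 0 or epoch is not None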
9 changes: 6 additions & 3 deletions sd3_train.py
@@ -14,6 +14,7 @@
import torch
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler

init_ipex()

@@ -730,6 +731,7 @@ def grad_hook(parameter: torch.Tensor):

loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
signal_handler = SignalHandler()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -883,12 +885,13 @@ def grad_hook(parameter: torch.Tensor):

optimizer_eval_fn()
sd3_train_utils.sample_images(
accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs
accelerator, args, None, global_step, mmdit, vae, [clip_l, clip_g, t5xxl], sample_prompts_te_outputs, force_sample=signal_handler.should_sample()
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
sd3_train_utils.save_sd3_model_on_epoch_end_or_stepwise(
args,
4 changes: 2 additions & 2 deletions sd3_train_network.py
@@ -281,12 +281,12 @@ def cache_text_encoder_outputs_if_needed(
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit, force_sample=False):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

sd3_train_utils.sample_images(
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs, force_sample=force_sample
)

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
8 changes: 6 additions & 2 deletions sdxl_train.py
@@ -11,6 +11,7 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler


init_ipex()
@@ -622,6 +623,7 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.log({}, step=0)

loss_recorder = train_util.LossRecorder()
signal_handler = SignalHandler()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -770,11 +772,13 @@ def optimizer_hook(parameter: torch.Tensor):
tokenizers,
[text_encoder1, text_encoder2],
unet,
force_sample=signal_handler.should_sample()
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
9 changes: 6 additions & 3 deletions sdxl_train_control_net.py
@@ -9,6 +9,7 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler

init_ipex()

@@ -456,7 +457,7 @@ def remove_model(old_ckpt_name):
unet,
controlnet=control_net,
)

signal_handler = SignalHandler()
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -578,11 +579,13 @@ def remove_model(old_ckpt_name):
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
force_sample=signal_handler.should_sample()
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(control_net))
4 changes: 2 additions & 2 deletions sdxl_train_network.py
@@ -207,8 +207,8 @@ def call_unet(
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, force_sample=False):
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, force_sample=force_sample)


def setup_parser() -> argparse.ArgumentParser:
4 changes: 2 additions & 2 deletions sdxl_train_textual_inversion.py
@@ -74,10 +74,10 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond
return noise_pred

def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement, force_sample=False
):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement, force_sample=force_sample
)

def save_weights(self, file, updated_embs, save_dtype, metadata):
10 changes: 7 additions & 3 deletions train_control_net.py
@@ -14,6 +14,7 @@
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.signal_handler import SignalHandler

init_ipex()

@@ -431,7 +432,8 @@ def remove_model(old_ckpt_name):
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)


signal_handler = SignalHandler()
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
@@ -539,11 +541,13 @@ def remove_model(old_ckpt_name):
text_encoder,
unet,
controlnet=controlnet,
force_sample=signal_handler.should_sample(),
)

signal_handler.reset_sample()
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if signal_handler.should_save() or (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0):
accelerator.wait_for_everyone()
signal_handler.reset_save()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(
(The remaining changed files in this commit, including the new library/signal_handler.py, are not shown in this excerpt.)
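As a usage note, and assuming the SIGUSR1/SIGUSR2 mapping sketched after fine_tune.py above (the actual signals are defined in the unshown library/signal_handler.py), an on-demand sample or save could be requested from outside the training process roughly like this:

# Hypothetical driver script; the PID and the signal mapping are assumptions.
import os
import signal

training_pid = 12345  # PID of the running trainer process
os.kill(training_pid, signal.SIGUSR1)  # request sample images at the next step
os.kill(training_pid, signal.SIGUSR2)  # request a checkpoint at the next step

Equivalently, kill -USR1 <pid> and kill -USR2 <pid> from a shell would do the same on POSIX systems.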