From 1c0ae306e551ede5bd162819debb4d80a7fe620b Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Fri, 3 Jan 2025 15:43:02 -0500
Subject: [PATCH] Add missing functions for training batch

---
 train_network.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index ce34f26d3..377ddf48e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -318,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,
     # endregion

     def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
-
+
         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -1333,6 +1333,11 @@ def remove_model(old_ckpt_name):
                     continue

                 with accelerator.accumulate(training_model):
+                    on_step_start_for_network(text_encoder, unet)
+
+                    # temporary, for batch processing
+                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+
                     loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
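
Below is a minimal sketch (not the sd-scripts implementation) of the control
flow this patch wires up: both on-step hooks fire inside the training-step
body before the batch is turned into a loss by process_batch. The
SimpleTrainer class, its stub bodies, and the toy loss are illustrative
assumptions; only the call order mirrors the patch, and Accelerate's
accumulate/backward plumbing is replaced by a plain loss.backward() to keep
the sketch self-contained.

import torch


class SimpleTrainer:
    def on_step_start(self, batch):
        # Stand-in for self.on_step_start(args, accelerator, network, ...):
        # per-step setup that needs the current batch, e.g. moving extra
        # conditioning tensors to the training device/dtype.
        pass

    def process_batch(self, batch, is_train=True):
        # Stand-in for the real process_batch: in sd-scripts this encodes
        # latents, adds noise, runs the network, and returns the loss.
        # A toy MSE against zero keeps this sketch runnable on its own.
        return batch["latents"].pow(2).mean()

    def train_step(self, batch):
        # Mirrors the patched loop body: hooks first, then loss + backward.
        self.on_step_start(batch)
        loss = self.process_batch(batch, is_train=True)
        loss.backward()
        return loss


if __name__ == "__main__":
    trainer = SimpleTrainer()
    batch = {"latents": torch.randn(2, 4, 8, 8, requires_grad=True)}
    print(trainer.train_step(batch).item())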