Add missing functions for training batch
rockerBOO committed Jan 3, 2025
1 parent 1f9ba40 commit 1c0ae30
Showing 1 changed file with 6 additions and 1 deletion.
train_network.py
@@ -318,7 +318,7 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,
     # endregion

     def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:

         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -1333,6 +1333,11 @@ def remove_model(old_ckpt_name):
                     continue

                 with accelerator.accumulate(training_model):
+                    on_step_start_for_network(text_encoder, unet)
+
+                    # temporary, for batch processing
+                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
+
                     loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
                     accelerator.backward(loss)
                     if accelerator.sync_gradients:
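The refactored process_batch takes is_train and an optional timesteps_list, which suggests the same code path can also score a held-out batch at fixed noise levels without touching the weights. Below is a minimal, hypothetical sketch of such a call; apart from the process_batch signature itself, every name here (trainer, val_dataloader, the strategy and model objects) is an assumed stand-in for state the training loop already builds, not code from this commit.

import torch

# Hypothetical validation pass reusing process_batch (a sketch, not commit code).
# `trainer` is assumed to be the NetworkTrainer instance from train_network.py;
# the remaining arguments are the same objects the training loop passes above.
fixed_timesteps = [100, 300, 500, 700, 900]  # sample the noise schedule evenly

val_loss = 0.0
with torch.no_grad():
    for val_batch in val_dataloader:
        loss = trainer.process_batch(
            val_batch, tokenizers, text_encoders, unet, network, vae,
            noise_scheduler, vae_dtype, weight_dtype, accelerator, args,
            text_encoding_strategy, tokenize_strategy,
            is_train=False,                  # presumably disables training-only behavior
            train_text_encoder=False,
            train_unet=False,
            timesteps_list=fixed_timesteps,  # fixed timesteps keep the metric stable
        )
        val_loss += loss.item()
print(f"validation loss: {val_loss / len(val_dataloader):.4f}")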

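For readers unfamiliar with the surrounding loop: accelerator.accumulate(), accelerator.backward(), and accelerator.sync_gradients come from Hugging Face Accelerate. Here is a self-contained sketch of that pattern with a toy model standing in for the LoRA network; nothing in it is from this commit.

import torch
from accelerate import Accelerator

# Toy reproduction of the accumulate/backward/sync pattern in the diff above.
accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

for _ in range(16):
    x = torch.randn(2, 8, device=accelerator.device)
    y = torch.randn(2, 1, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)      # scales the loss for accumulation
        if accelerator.sync_gradients:  # True only on the boundary step
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                # the prepared optimizer skips non-sync steps
        optimizer.zero_grad()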