Skip to content

Commit

Permalink
Cleanup unused code and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jan 6, 2025
1 parent f4840ef commit 1c63e7c
Showing 1 changed file with 70 additions and 15 deletions.
85 changes: 70 additions & 15 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,27 @@ def on_step_start(self, args, accelerator, network, text_encoders, unet, batch,

# endregion

def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:

def process_batch(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_sd.SdTextEncodingStrategy,
tokenize_strategy: strategy_sd.SdTokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True
) -> torch.Tensor:
"""
Process a batch for the network
"""
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
Expand All @@ -334,7 +353,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au

latents = self.shift_scale_latents(args, latents)


text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
Expand Down Expand Up @@ -371,13 +389,6 @@ def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: Au
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]

batch_size = latents.shape[0]


# Predict the noise residual
# and add noise to the latents
# with noise offset and/or multires noise if specified

# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
Expand Down Expand Up @@ -1288,7 +1299,23 @@ def remove_model(old_ckpt_name):
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
loss = self.process_batch(batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
)

accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
Expand Down Expand Up @@ -1366,12 +1393,26 @@ def remove_model(old_ckpt_name):
if val_step >= validation_steps:
break

loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)

loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
)

val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })

if is_tracking:
logs = {"loss/current_val_loss": loss.detach().item()}
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
Expand All @@ -1397,7 +1438,21 @@ def remove_model(old_ckpt_name):
if val_step >= validation_steps:
break

loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
)

current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
Expand Down

0 comments on commit 1c63e7c

Please sign in to comment.