From 6bee18db4fbf62ebd2a1da88a5851c48f2e06c54 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 15:12:27 +0900 Subject: [PATCH] fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed --- README.md | 2 ++ library/sd3_models.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f02725191..6162359d1 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 7, 2024: +- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. Dec 3, 2024: diff --git a/library/sd3_models.py b/library/sd3_models.py index 2f3c82eed..e4a931861 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -870,8 +870,10 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # remove pos_embed to free up memory up to 0.4 GB - self.pos_embed = None + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes))