fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed
kohya-ss committed Dec 7, 2024
1 parent 8b36d90 commit 6bee18d
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -14,6 +14,8 @@ The command to install PyTorch is as follows:

### Recent Updates

Dec 7, 2024:
- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.

Dec 3, 2024:

6 changes: 4 additions & 2 deletions library/sd3_models.py
@@ -870,8 +870,10 @@ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Opti
         self.use_scaled_pos_embed = use_scaled_pos_embed

         if self.use_scaled_pos_embed:
-            # remove pos_embed to free up memory up to 0.4 GB
-            self.pos_embed = None
+            # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved
+            # self.pos_embed = None
+            # move pos_embed to CPU to free up memory up to 0.4 GB
+            self.pos_embed = self.pos_embed.cpu()

             # remove duplicates and sort latent sizes in ascending order
             latent_sizes = list(set(latent_sizes))
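
The effect of this change can be reproduced in isolation. The sketch below is not code from the repository: `ToyModel` and `saved_keys` are hypothetical names, and it assumes `pos_embed` is a persistent buffer registered with `register_buffer`. Under that assumption, assigning `None` to the buffer makes `state_dict()` omit the key, while moving the tensor to the CPU keeps it in the saved state.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the real model; only pos_embed matters here."""

    def __init__(self, num_positions: int = 16, hidden_size: int = 8):
        super().__init__()
        # Assumption: pos_embed is a persistent buffer, so it is normally
        # included in state_dict() and therefore in the saved checkpoint.
        self.register_buffer("pos_embed", torch.zeros(1, num_positions, hidden_size))
        self.proj = nn.Linear(hidden_size, hidden_size)


def saved_keys(model: nn.Module) -> set:
    """Return the set of keys that would be written when saving the model."""
    return set(model.state_dict().keys())


# Old behaviour: drop the tensor entirely to free memory.
old = ToyModel()
old.pos_embed = None

# New behaviour: keep the tensor, but hold it on the CPU.
new = ToyModel()
new.pos_embed = new.pos_embed.cpu()

old_keys = saved_keys(old)
new_keys = saved_keys(new)

print("pos_embed saved (old):", "pos_embed" in old_keys)
print("pos_embed saved (new):", "pos_embed" in new_keys)
```

A buffer set to `None` is skipped by `state_dict()`, which matches the corruption described in the README entry: the checkpoint is written without `pos_embed`. Keeping the tensor on the CPU still frees accelerator memory but leaves the key in the saved state.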