
Commit

fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size
kohya-ss committed Nov 30, 2024
1 parent 2a61fc0 commit 9c885e5
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions library/sd3_models.py
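
For context before the diff: cropped_scaled_pos_embed resolves the requested latent size (h, w) against resolution_area_to_latent_size, a list of (area, latent_size) pairs sorted by area, picking the first entry whose area covers the request. A minimal sketch of that lookup and the fallback this commit introduces; the table values below are illustrative, not the ones defined in library/sd3_models.py:

import math

# Illustrative table: (max latent area, latent grid size)
resolution_area_to_latent_size = [
    (64 ** 2, 64),    # up to 64x64 latents (~512px images)
    (96 ** 2, 96),    # up to 96x96 latents (~768px)
    (128 ** 2, 128),  # up to 128x128 latents (~1024px)
]

def pick_patched_size(h: int, w: int) -> int:
    area = h * w
    for max_area, patched_size_ in resolution_area_to_latent_size:
        if area <= max_area:
            return patched_size_
    # Before this commit: raise ValueError(...).
    # After: fall back to the largest known latent size.
    return resolution_area_to_latent_size[-1][1]

print(pick_patched_size(64, 64))    # -> 64
print(pick_patched_size(160, 160))  # -> 128 (fallback instead of ValueError)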
@@ -1017,22 +1017,35 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
                 patched_size = patched_size_
                 break
         if patched_size is None:
-            raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # use largest latent size
+            patched_size = self.resolution_area_to_latent_size[-1][1]
 
         pos_embed = self.resolution_pos_embeds[patched_size]
-        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
+        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))  # max size, patched_size * POS_EMBED_MAX_RATIO
         if h > pos_embed_size or w > pos_embed_size:
             # # fallback to normal pos_embed
             # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
             # extend pos_embed size
             logger.warning(
-                f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
+                f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
             )
-            pos_embed_size = max(h, w)
-            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
+            patched_size = max(h, w)
+            grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
+            pos_embed_size = grid_size
+            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
             pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
             self.resolution_pos_embeds[patched_size] = pos_embed
-            logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
+            logger.info(f"Added pos_embed for size {patched_size}x{patched_size}")
+
+            # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2))
+            # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu()
+            # print(diff.abs().max(), diff.abs().mean())
+
+            # insert to resolution_area_to_latent_size, by adding and sorting
+            area = pos_embed_size**2
+            self.resolution_area_to_latent_size.append((area, patched_size))
+            self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size)
 
         if not random_crop:
             top = (pos_embed_size - h) // 2
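
To make the new branch concrete, a standalone sketch of the on-demand extension. POS_EMBED_MAX_RATIO = 1.5 and hidden_size = 1536 are assumed values for illustration (the diff only references MMDiT.POS_EMBED_MAX_RATIO and self.hidden_size), and the helper below is a stand-in with the same call shape as get_scaled_2d_sincos_pos_embed, which lives in library/sd3_models.py:

import numpy as np
import torch

POS_EMBED_MAX_RATIO = 1.5  # assumed value for illustration

def stand_in_scaled_2d_sincos_pos_embed(hidden_size, grid_size, sample_size):
    # Stand-in for get_scaled_2d_sincos_pos_embed: same call shape,
    # returns a (grid_size * grid_size, hidden_size) numpy array.
    return np.zeros((grid_size * grid_size, hidden_size), dtype=np.float32)

resolution_pos_embeds = {}                        # latent_size -> pos_embed tensor
resolution_area_to_latent_size = [(96 ** 2, 64)]  # illustrative single entry
hidden_size = 1536                                # assumed for illustration

h, w = 120, 48  # requested latent window, taller than any cached grid
patched_size = max(h, w)                             # 120
grid_size = int(patched_size * POS_EMBED_MAX_RATIO)  # 180
pos_embed = stand_in_scaled_2d_sincos_pos_embed(hidden_size, grid_size, sample_size=patched_size)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)  # (1, 180*180, 1536)

# Cache the new embedding and register its area so later calls can find it
# via the sorted lookup sketched earlier.
resolution_pos_embeds[patched_size] = pos_embed
resolution_area_to_latent_size.append((grid_size ** 2, patched_size))
resolution_area_to_latent_size = sorted(resolution_area_to_latent_size)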
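
The unchanged tail of the method (the top = (pos_embed_size - h) // 2 context line) then cuts an h x w window out of the cached grid. A hedged sketch of that center crop, assuming the embedding is stored row-major as (1, size * size, hidden) and that width is handled symmetrically:

import torch

def center_crop_pos_embed(pos_embed, pos_embed_size, h, w):
    # pos_embed: (1, pos_embed_size * pos_embed_size, hidden_size)
    top = (pos_embed_size - h) // 2   # matches the context line in the diff
    left = (pos_embed_size - w) // 2  # assumed symmetric handling for width
    grid = pos_embed.reshape(1, pos_embed_size, pos_embed_size, -1)
    window = grid[:, top : top + h, left : left + w, :]
    return window.reshape(1, h * w, -1)

pe = torch.zeros(1, 96 * 96, 1536)
print(center_crop_pos_embed(pe, 96, 64, 48).shape)  # torch.Size([1, 3072, 1536])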
