diff --git a/library/sd3_models.py b/library/sd3_models.py
index 8b90205db..2f3c82eed 100644
--- a/library/sd3_models.py
+++ b/library/sd3_models.py
@@ -1017,22 +1017,35 @@ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: b
                 patched_size = patched_size_
                 break
         if patched_size is None:
-            raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
+            # use largest latent size
+            patched_size = self.resolution_area_to_latent_size[-1][1]
 
         pos_embed = self.resolution_pos_embeds[patched_size]
-        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
+        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))  # max size, patched_size * POS_EMBED_MAX_RATIO
         if h > pos_embed_size or w > pos_embed_size:
             # # fallback to normal pos_embed
             # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
             # extend pos_embed size
             logger.warning(
-                f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
+                f"Add new pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
             )
-            pos_embed_size = max(h, w)
-            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
+            patched_size = max(h, w)
+            grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
+            pos_embed_size = grid_size
+            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
             pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
             self.resolution_pos_embeds[patched_size] = pos_embed
-            logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
+            logger.info(f"Added pos_embed for size {patched_size}x{patched_size}")
+
+            # print(torch.allclose(pos_embed.to(torch.float32).cpu(), self.pos_embed.to(torch.float32).cpu(), atol=5e-2))
+            # diff = pos_embed.to(torch.float32).cpu() - self.pos_embed.to(torch.float32).cpu()
+            # print(diff.abs().max(), diff.abs().mean())
+
+            # insert to resolution_area_to_latent_size, by adding and sorting
+            area = pos_embed_size**2
+            self.resolution_area_to_latent_size.append((area, patched_size))
+            self.resolution_area_to_latent_size = sorted(self.resolution_area_to_latent_size)
 
         if not random_crop:
             top = (pos_embed_size - h) // 2
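
For context, a minimal standalone sketch of the bucket-selection-and-extension behavior this hunk introduces. The bucket values and the `POS_EMBED_MAX_RATIO` value below are assumptions for illustration only; the real method also builds and caches the corresponding sincos position embedding via `get_scaled_2d_sincos_pos_embed`, which is omitted here.

```python
# Illustrative assumptions: the real buckets and the
# MMDiT.POS_EMBED_MAX_RATIO constant live in library/sd3_models.py.
POS_EMBED_MAX_RATIO = 1.5
resolution_area_to_latent_size = [(32 * 32, 32), (64 * 64, 64)]  # sorted by area


def select_latent_size(h: int, w: int) -> int:
    """Pick the smallest bucket whose area covers h*w; grow the
    bucket list instead of raising when the image is too large."""
    area = h * w
    patched_size = None
    for area_, size_ in resolution_area_to_latent_size:
        if area <= area_:
            patched_size = size_
            break
    if patched_size is None:
        # fall back to the largest registered size instead of raising
        patched_size = resolution_area_to_latent_size[-1][1]

    # side length of the (cached) scaled pos_embed grid for this bucket
    pos_embed_size = int(patched_size * POS_EMBED_MAX_RATIO)
    if h > pos_embed_size or w > pos_embed_size:
        # too tall or wide even for the largest grid: register a new
        # bucket sized to this request, keeping the list sorted so the
        # linear scan above keeps finding the smallest covering bucket
        patched_size = max(h, w)
        grid_size = int(patched_size * POS_EMBED_MAX_RATIO)
        resolution_area_to_latent_size.append((grid_size**2, patched_size))
        resolution_area_to_latent_size.sort()
    return patched_size


print(select_latent_size(24, 24))   # fits the first bucket -> 32
print(select_latent_size(200, 40))  # no bucket fits -> new bucket, 200
```

Keeping `resolution_area_to_latent_size` sorted after the append matters because the lookup is a linear scan that returns the first bucket whose area covers the request; an unsorted list could hand back an unnecessarily large grid on later calls.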