Skip to content

Commit

Permalink
Move get_huber_threshold_if_needed
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jan 3, 2025
1 parent 0522070 commit 695f389
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
return timesteps


def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
return None

b_size = timesteps.shape[0]
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"):
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
elif args.huber_schedule == "constant":
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

return result


def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
Expand Down Expand Up @@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.
return noise, noisy_latents, timesteps


def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
return None

b_size = timesteps.shape[0]
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"):
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
elif args.huber_schedule == "constant":
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")

return result


def get_noisy_latents(args, noise: torch.FloatTensor, noise_scheduler: DDPMScheduler, latents: torch.FloatTensor, timesteps: torch.IntTensor) -> torch.FloatTensor:
"""
Add noise to the latents according to the noise magnitude at each timestep
Expand Down

0 comments on commit 695f389

Please sign in to comment.