diff --git a/library/train_util.py b/library/train_util.py index 0f16a4f31..0907a8c03 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5905,27 +5905,6 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor return timesteps -def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: - if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): - return None - - b_size = timesteps.shape[0] - if args.huber_schedule == "exponential": - alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - result = torch.exp(-alpha * timesteps) * args.huber_scale - elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): - raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) - sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 - result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c - result = result.to(timesteps.device) - elif args.huber_schedule == "constant": - result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device) - else: - raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - return result def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor: @@ -6004,6 +5983,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch. 
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
    """
    Build the per-sample Huber threshold for the configured loss, if one is needed.

    Args:
        args: namespace providing `loss_type`, `huber_schedule`, `huber_c` and `huber_scale`.
        timesteps: tensor of sampled timesteps; its first dimension is the batch size.
        noise_scheduler: scheduler exposing `config.num_train_timesteps` (used by the
            "exponential" schedule) and `alphas_cumprod` (required by the "snr" schedule).

    Returns:
        None when `args.loss_type` is neither "huber" nor "smooth_l1"; otherwise a tensor
        of thresholds on `timesteps.device`.

    Raises:
        NotImplementedError: if the schedule is "snr" but the scheduler has no
            `alphas_cumprod`, or if `args.huber_schedule` is not a known schedule.
    """
    if args.loss_type not in ("huber", "smooth_l1"):
        return None

    # Read the batch size up front so malformed `timesteps` fail before any schedule logic.
    num_samples = timesteps.shape[0]
    schedule = args.huber_schedule

    if schedule == "constant":
        # Same threshold for every sample in the batch.
        return torch.full((num_samples,), args.huber_c * args.huber_scale, device=timesteps.device)

    if schedule == "exponential":
        # Decay rate chosen so the unscaled value reaches `huber_c` at `num_train_timesteps`.
        decay = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
        return torch.exp(-decay * timesteps) * args.huber_scale

    if schedule == "snr":
        if not hasattr(noise_scheduler, "alphas_cumprod"):
            raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
        # The lookup is done with CPU indices; the result is moved back to the timesteps' device.
        cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
        sigmas = ((1.0 - cumprod) / cumprod) ** 0.5
        threshold = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
        return threshold.to(timesteps.device)

    raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")