-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: test euler negative and max (#10)
- Loading branch information
Showing 2 changed files with 247 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
241 changes: 241 additions & 0 deletions
241
stablepy/diffusers_vanilla/extra_scheduler/scheduling_euler_discrete_variants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
from typing import Optional, Tuple, Union | ||
|
||
from diffusers.utils import logging | ||
from diffusers.utils.torch_utils import randn_tensor | ||
import torch | ||
import math | ||
from diffusers.schedulers.scheduling_euler_discrete import ( | ||
EulerDiscreteScheduler, | ||
EulerDiscreteSchedulerOutput, | ||
) | ||
|
||
# Module-level logger obtained from diffusers' logging utilities. The scheduler
# `step` methods in this module use it to warn when `scale_model_input` has not
# been called before stepping.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||
|
||
class EulerDiscreteSchedulerNegative(EulerDiscreteScheduler):
    """
    `EulerDiscreteScheduler` subclass that overrides only `step`.

    The overridden `step` negates the sample in its update (`-sample - derivative * dt`) when the next sigma is
    positive and `step_index // 2 == 1`; on every other step it applies `sample + derivative * dt`.
    """

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one scheduler step using the model output, then increment the internal step index.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current timestep. Python ints and integer tensors are rejected. It is used to initialise the
                step index when that has not been set yet.
            sample (`torch.Tensor`):
                The current sample; it is upcast to float32 for the computation.
            s_churn (`float`):
                Divided by `len(self.sigmas) - 1` and combined through `max` with `2**0.5 - 1` to obtain `gamma`.
            s_tmin (`float`):
                Lower bound of the sigma range in which `gamma` is non-zero.
            s_tmax (`float`):
                Upper bound of the sigma range in which `gamma` is non-zero.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for the noise applied to the sample.
            generator (`torch.Generator`, *optional*):
                Random number generator forwarded to `randn_tensor`.
            return_dict (`bool`):
                Whether to return an `EulerDiscreteSchedulerOutput` instead of a plain tuple.

        Returns:
            `EulerDiscreteSchedulerOutput` or `tuple`:
                If `return_dict` is `True`, an `EulerDiscreteSchedulerOutput` holding `prev_sample` and
                `pred_original_sample`; otherwise the tuple `(prev_sample, pred_original_sample)`.

        Raises:
            ValueError: If `timestep` is an integer index, or if `self.config.prediction_type` is not one of the
                handled values.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            integer_timestep_message = (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )
            raise ValueError(integer_timestep_message)

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Run the arithmetic in float32 to avoid precision issues when computing prev_sample.
        latents = sample.to(torch.float32)

        idx = self.step_index
        sigma = self.sigmas[idx]

        # `max` is used here, so whenever sigma lies inside [s_tmin, s_tmax] gamma is at least 2**0.5 - 1 and the
        # noise branch below runs even with s_churn == 0.
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
        else:
            gamma = 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=model_output.device,
                generator=generator,
            )
            scaled_noise = noise * s_noise
            noise_scale = (sigma_hat**2 - sigma**2) ** 0.5
            # The scaled noise is subtracted from the sample in this variant.
            latents = latents - scaled_noise * noise_scale

        # 1. Predicted original sample (x_0) from the model output.
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility.
        prediction_type = self.config.prediction_type
        if prediction_type in ("original_sample", "sample"):
            pred_original_sample = model_output
        elif prediction_type == "epsilon":
            pred_original_sample = latents - sigma_hat * model_output
        elif prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            c_out = -sigma / (sigma**2 + 1) ** 0.5
            pred_original_sample = model_output * c_out + (latents / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. ODE derivative and step size towards the next sigma.
        derivative = (latents - pred_original_sample) / sigma_hat
        sigma_next = self.sigmas[idx + 1]
        dt = sigma_next - sigma_hat

        # 3. Update; the sample is negated only while the next sigma is positive and idx // 2 == 1.
        if sigma_next > 0 and idx // 2 == 1:
            prev_sample = -latents - derivative * dt
        else:
            prev_sample = latents + derivative * dt

        # Cast back to the dtype the model produced.
        prev_sample = prev_sample.to(model_output.dtype)

        # Step finished: move the internal index forward by one.
        self._step_index += 1

        if return_dict:
            return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

        return (
            prev_sample,
            pred_original_sample,
        )
|
||
|
||
class EulerDiscreteSchedulerMax(EulerDiscreteScheduler):
    """
    `EulerDiscreteScheduler` subclass that overrides only `step`.

    The overridden `step` multiplies the `derivative * dt` update by
    `math.cos(step_index + 1) / (step_index + 1) + 1` before adding it to the sample.
    """

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one scheduler step using the model output, then increment the internal step index.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current timestep. Python ints and integer tensors are rejected. It is used to initialise the
                step index when that has not been set yet.
            sample (`torch.Tensor`):
                The current sample; it is upcast to float32 for the computation.
            s_churn (`float`):
                Divided by `len(self.sigmas) - 1` and combined through `max` with `2**0.5 - 1` to obtain `gamma`.
            s_tmin (`float`):
                Lower bound of the sigma range in which `gamma` is non-zero.
            s_tmax (`float`):
                Upper bound of the sigma range in which `gamma` is non-zero.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for the noise applied to the sample.
            generator (`torch.Generator`, *optional*):
                Random number generator forwarded to `randn_tensor`.
            return_dict (`bool`):
                Whether to return an `EulerDiscreteSchedulerOutput` instead of a plain tuple.

        Returns:
            `EulerDiscreteSchedulerOutput` or `tuple`:
                If `return_dict` is `True`, an `EulerDiscreteSchedulerOutput` holding `prev_sample` and
                `pred_original_sample`; otherwise the tuple `(prev_sample, pred_original_sample)`.

        Raises:
            ValueError: If `timestep` is an integer index, or if `self.config.prediction_type` is not one of the
                handled values.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            integer_timestep_message = (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )
            raise ValueError(integer_timestep_message)

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Run the arithmetic in float32 to avoid precision issues when computing prev_sample.
        latents = sample.to(torch.float32)

        idx = self.step_index
        sigma = self.sigmas[idx]

        # `max` is used here, so whenever sigma lies inside [s_tmin, s_tmax] gamma is at least 2**0.5 - 1 and the
        # noise branch below runs even with s_churn == 0.
        if s_tmin <= sigma <= s_tmax:
            gamma = max(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
        else:
            gamma = 0.0

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=model_output.device,
                generator=generator,
            )
            scaled_noise = noise * s_noise
            noise_scale = (sigma_hat**2 - sigma**2) ** 0.5
            # The scaled noise is subtracted from the sample in this variant.
            latents = latents - scaled_noise * noise_scale

        # 1. Predicted original sample (x_0) from the model output.
        # NOTE: "original_sample" should not be an expected prediction_type but is left in for
        # backwards compatibility.
        prediction_type = self.config.prediction_type
        if prediction_type in ("original_sample", "sample"):
            pred_original_sample = model_output
        elif prediction_type == "epsilon":
            pred_original_sample = latents - sigma_hat * model_output
        elif prediction_type == "v_prediction":
            # denoised = model_output * c_out + input * c_skip
            c_out = -sigma / (sigma**2 + 1) ** 0.5
            pred_original_sample = model_output * c_out + (latents / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. ODE derivative and step size towards the next sigma.
        derivative = (latents - pred_original_sample) / sigma_hat
        dt = self.sigmas[idx + 1] - sigma_hat

        # 3. Update, scaled by a factor that depends only on the 1-based step number.
        step_number = idx + 1
        step_scale = math.cos(step_number) / step_number + 1
        prev_sample = latents + step_scale * derivative * dt

        # Cast back to the dtype the model produced.
        prev_sample = prev_sample.to(model_output.dtype)

        # Step finished: move the internal index forward by one.
        self._step_index += 1

        if return_dict:
            return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

        return (
            prev_sample,
            pred_original_sample,
        )