Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance code readability with type hints and docstrings #400

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 69 additions & 28 deletions ldm/models/diffusion/dpm_solver/sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"""SAMPLING ONLY."""
import torch

from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
Expand All @@ -8,46 +7,87 @@
"v": "v"
}


class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
def __init__(self, model: torch.nn.Module, device: torch.device = torch.device("cuda"), **kwargs) -> None:
"""
Initialize the DPMSolverSampler.

Args:
model (torch.nn.Module): The model to use for sampling.
device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
**kwargs: Additional keyword arguments.
"""
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
def register_buffer(self, name: str, attr: torch.Tensor) -> None:
"""
Store a value as an attribute on the sampler; tensors are first moved to `self.device`.

Args:
name (str): The name of the buffer.
attr (torch.Tensor): The tensor to register as a buffer.
"""
if isinstance(attr, torch.Tensor):
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)

@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
S: int,
batch_size: int,
shape: tuple,
conditioning: dict = None,
callback: callable = None,
normals_sequence: list = None,
img_callback: callable = None,
quantize_x0: bool = False,
eta: float = 0.,
mask: torch.Tensor = None,
x0: torch.Tensor = None,
temperature: float = 1.,
noise_dropout: float = 0.,
score_corrector: callable = None,
corrector_kwargs: dict = None,
verbose: bool = True,
x_T: torch.Tensor = None,
log_every_t: int = 100,
unconditional_guidance_scale: float = 1.,
unconditional_conditioning: torch.Tensor = None,
**kwargs) -> tuple:
"""
Perform sampling using the DPM Solver.

Args:
S (int): Number of steps.
batch_size (int): Batch size for sampling.
shape (tuple): Shape of the samples (C, H, W).
conditioning (dict, optional): Conditioning information. Defaults to None.
callback (callable, optional): Callback function. Defaults to None.
normals_sequence (list, optional): Sequence of normals. Defaults to None.
img_callback (callable, optional): Image callback function. Defaults to None.
quantize_x0 (bool, optional): Flag for quantizing x0. Defaults to False.
eta (float, optional): Eta parameter. Defaults to 0.0.
mask (torch.Tensor, optional): Mask tensor. Defaults to None.
x0 (torch.Tensor, optional): Initial x0 tensor. Defaults to None.
temperature (float, optional): Temperature parameter. Defaults to 1.0.
noise_dropout (float, optional): Noise dropout parameter. Defaults to 0.0.
score_corrector (callable, optional): Score corrector. Defaults to None.
corrector_kwargs (dict, optional): Keyword arguments for the score corrector. Defaults to None.
verbose (bool, optional): Verbose flag. Defaults to True.
x_T (torch.Tensor, optional): Initial x_T tensor. Defaults to None.
log_every_t (int, optional): Log interval. Defaults to 100.
unconditional_guidance_scale (float, optional): Guidance scale for unconditional sampling. Defaults to 1.0.
unconditional_conditioning (torch.Tensor, optional): Conditioning tensor for unconditional sampling. Defaults to None.
**kwargs: Additional keyword arguments.

Returns:
tuple: The sampled tensor (moved to `device`) and None as the second element.
"""
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
Expand Down Expand Up @@ -94,3 +134,4 @@ def sample(self,
lower_order_final=True)

return x.to(device), None