Commit fbba0f9
bring in prediction of v objective, combining the findings from progressive distillation paper and imagen-video to the eventual extension of dalle2 to make-a-video
lucidrains committed Oct 29, 2022
1 parent 9f37705 commit fbba0f9
Showing 3 changed files with 69 additions and 17 deletions.
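
For context: the v objective from the Progressive Distillation paper (also adopted by Imagen Video) reparameterizes the network target. With the forward process x_t = α_t·x_0 + σ_t·ε, the target becomes v = α_t·ε − σ_t·x_0, and x_0 is recovered as x_0 = α_t·x_t − σ_t·v. Below is a minimal sketch of these identities, assuming the variance-preserving convention α_t² + σ_t² = 1; it mirrors the method names added in this commit, but with scalar alpha, sigma stand-ins in place of the scheduler's extract calls.

```python
# Sketch of the identities behind calculate_v / predict_start_from_v,
# assuming the variance-preserving convention alpha**2 + sigma**2 == 1.
import torch

def q_sample(x0, alpha, sigma, eps):
    return alpha * x0 + sigma * eps        # forward process: x_t

def calculate_v(x0, alpha, sigma, eps):
    return alpha * eps - sigma * x0        # v target (Salimans & Ho, 2022)

def predict_start_from_v(xt, alpha, sigma, v):
    return alpha * xt - sigma * v          # recover x_0 from x_t and v

x0, eps = torch.randn(4, 8), torch.randn(4, 8)
t = torch.rand(4, 1)
alpha, sigma = torch.cos(t), torch.sin(t)  # guarantees alpha^2 + sigma^2 = 1

xt = q_sample(x0, alpha, sigma, eps)
v = calculate_v(x0, alpha, sigma, eps)
assert torch.allclose(predict_start_from_v(xt, alpha, sigma, v), x0, atol = 1e-5)
```

Substituting x_t shows why the inversion works: α·x_t − σ·v = α²·x_0 + ασ·ε − ασ·ε + σ²·x_0 = (α² + σ²)·x_0 = x_0.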
10 changes: 10 additions & 0 deletions README.md
@@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```

+```bibtex
+@article{Salimans2022ProgressiveDF,
+    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
+    author  = {Tim Salimans and Jonathan Ho},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2202.00512}
+}
+```
+
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
74 changes: 58 additions & 16 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -619,14 +619,20 @@ def q_posterior(self, x_start, x_t, t):
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

-   def q_sample(self, x_start, t, noise=None):
+   def q_sample(self, x_start, t, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

+   def calculate_v(self, x_start, t, noise = None):
+       return (
+           extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+           extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+       )
+
    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        shape = x_from.shape
        noise = default(noise, lambda: torch.randn_like(x_from))
@@ -638,6 +644,12 @@ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):

        return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

+   def predict_start_from_v(self, x_t, t, v):
+       return (
+           extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+           extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+       )
+
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
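
As an aside, the predict_start_from_noise method truncated above uses reciprocal coefficients because rearranging x_t = α·x_0 + σ·ε gives x_0 = (1/α)·x_t − (σ/α)·ε; sqrt_recip_alphas_cumprod is 1/α. A quick hedged check with scalar stand-ins:

```python
# x_t = alpha * x0 + sigma * eps  =>  x0 = (1 / alpha) * x_t - (sigma / alpha) * eps
import torch

x0, eps = torch.randn(2, 4), torch.randn(2, 4)
alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)  # alpha^2 + sigma^2 = 1
xt = alpha * x0 + sigma * eps
assert torch.allclose((1 / alpha) * xt - (sigma / alpha) * eps, x0, atol = 1e-5)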
@@ -1146,6 +1158,7 @@ def __init__(
        image_cond_drop_prob = None,
        loss_type = "l2",
        predict_x_start = True,
+       predict_v = False,
        beta_schedule = "cosine",
        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
        sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1197,6 +1210,7 @@ def __init__(
        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

        self.predict_x_start = predict_x_start
+       self.predict_v = predict_v # takes precedence over predict_x_start

        # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132

@@ -1226,7 +1240,9 @@ def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = Fal

        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)

-       if self.predict_x_start:
+       if self.predict_v:
+           x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+       elif self.predict_x_start:
            x_start = pred
        else:
            x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1299,7 +1315,9 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal

            # derive x0

-           if self.predict_x_start:
+           if self.predict_v:
+               x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+           elif self.predict_x_start:
                x_start = pred
            else:
                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
@@ -1314,7 +1332,7 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal

            # predict noise

-           if self.predict_x_start:
+           if self.predict_x_start or self.predict_v:
                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
            else:
                pred_noise = pred
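
The DDIM update still needs ε, so the v path first recovers x_0 and then converts back to noise via predict_noise_from_start. A hedged sketch of that conversion with scalar stand-ins; under α² + σ² = 1 an equivalent shortcut, ε = α·v + σ·x_t, also holds:

```python
import torch

def predict_noise_from_start(xt, alpha, sigma, x0):
    # rearrange x_t = alpha * x0 + sigma * eps  =>  eps = (x_t - alpha * x0) / sigma
    return (xt - alpha * x0) / sigma

x0, eps = torch.randn(3, 5), torch.randn(3, 5)
alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)  # alpha^2 + sigma^2 = 1
xt = alpha * x0 + sigma * eps
v = alpha * eps - sigma * x0
assert torch.allclose(predict_noise_from_start(xt, alpha, sigma, x0), eps, atol = 1e-5)
assert torch.allclose(alpha * v + sigma * xt, eps, atol = 1e-5)  # eps directly from v
```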
@@ -1372,7 +1390,12 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
        if self.predict_x_start and self.training_clamp_l2norm:
            pred = self.l2norm_clamp_embed(pred)

-       target = noise if not self.predict_x_start else image_embed
+       if self.predict_v:
+           target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+       elif self.predict_x_start:
+           target = image_embed
+       else:
+           target = noise

        loss = self.noise_scheduler.loss_fn(pred, target)
        return loss
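
End to end, the p_losses change amounts to swapping the regression target. A toy training step under the same scalar assumptions, with a stand-in denoiser and no timestep conditioning (purely illustrative, not the repo's training loop):

```python
import torch
import torch.nn.functional as F

def v_objective_loss(model, x0, alpha, sigma):
    eps = torch.randn_like(x0)
    xt = alpha * x0 + sigma * eps          # q_sample
    target = alpha * eps - sigma * x0      # calculate_v
    return F.mse_loss(model(xt), target)   # l2 loss, as in loss_type = "l2"

model = torch.nn.Linear(8, 8)              # stand-in for the real denoising network
x0 = torch.randn(4, 8)
alpha, sigma = torch.tensor(0.8), torch.tensor(0.6)
v_objective_loss(model, x0, alpha, sigma).backward()
```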
@@ -2448,6 +2471,7 @@ def __init__(
        loss_type = 'l2',
        beta_schedule = None,
        predict_x_start = False,
+       predict_v = False,
        predict_x_start_for_latent_diffusion = False,
        image_sizes = None, # for cascading ddpm, image size at each stage
        random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2620,6 +2644,10 @@ def __init__(

        self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))

+       # predict v
+
+       self.predict_v = cast_tuple(predict_v, len(unets))
+
        # input image range

        self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
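
cast_tuple broadcasts a single setting into one value per unet in the cascade, so predict_v can be set globally or per stage. A hedged reimplementation of that behavior (illustrative, not the repo's exact helper):

```python
# broadcast a scalar config value to a per-unet tuple; leave tuples untouched
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

assert cast_tuple(False, 3) == (False, False, False)  # same flag for all 3 unets
assert cast_tuple((True, False), 2) == (True, False)  # per-unet override
```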
@@ -2731,14 +2759,16 @@ def dynamic_threshold(self, x):
        x = x.clamp(-s, s) / s
        return x

-   def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+   def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

-       if predict_x_start:
+       if predict_v:
+           x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+       elif predict_x_start:
            x_start = pred
        else:
            x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2765,9 +2795,9 @@ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodin
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
-   def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+   def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
        b, *_, device = *x.shape, x.device
-       model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+       model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
        noise = torch.randn_like(x)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2782,6 +2812,7 @@ def p_sample_loop_ddpm(
        image_embed,
        noise_scheduler,
        predict_x_start = False,
+       predict_v = False,
        learned_variance = False,
        clip_denoised = True,
        lowres_cond_img = None,
@@ -2840,6 +2871,7 @@ def p_sample_loop_ddpm(
                lowres_cond_img = lowres_cond_img,
                lowres_noise_level = lowres_noise_level,
                predict_x_start = predict_x_start,
+               predict_v = predict_v,
                noise_scheduler = noise_scheduler,
                learned_variance = learned_variance,
                clip_denoised = clip_denoised
@@ -2865,6 +2897,7 @@ def p_sample_loop_ddim(
        timesteps,
        eta = 1.,
        predict_x_start = False,
+       predict_v = False,
        learned_variance = False,
        clip_denoised = True,
        lowres_cond_img = None,
@@ -2926,7 +2959,9 @@ def p_sample_loop_ddim(

            # predict x0

-           if predict_x_start:
+           if predict_v:
+               x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+           elif predict_x_start:
                x_start = pred
            else:
                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2938,8 +2973,8 @@ def p_sample_loop_ddim(

            # predict noise

-           if predict_x_start:
-               pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+           if predict_x_start or predict_v:
+               pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
            else:
                pred_noise = pred

@@ -2975,7 +3010,7 @@ def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):

        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

-   def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+   def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        # normalize to [-1, 1]
@@ -3020,7 +3055,12 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres

        pred, _ = self.parse_unet_output(learned_variance, unet_output)

-       target = noise if not predict_x_start else x_start
+       if predict_v:
+           target = noise_scheduler.calculate_v(x_start, times, noise)
+       elif predict_x_start:
+           target = x_start
+       else:
+           target = noise

        loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3106,7 +3146,7 @@ def sample(
        num_unets = self.num_unets
        cond_scale = cast_tuple(cond_scale, num_unets)

-       for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+       for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
            if unet_number < start_at_unet_number:
                continue # It's the easiest way to do it

@@ -3142,6 +3182,7 @@ def sample(
                text_encodings = text_encodings,
                cond_scale = unet_cond_scale,
                predict_x_start = predict_x_start,
+               predict_v = predict_v,
                learned_variance = learned_variance,
                clip_denoised = not is_latent_diffusion,
                lowres_cond_img = lowres_cond_img,
@@ -3181,6 +3222,7 @@ def forward(
        lowres_conditioner = self.lowres_conds[unet_index]
        target_image_size = self.image_sizes[unet_index]
        predict_x_start = self.predict_x_start[unet_index]
+       predict_v = self.predict_v[unet_index]
        random_crop_size = self.random_crop_sizes[unet_index]
        learned_variance = self.learned_variance[unet_index]
        b, c, h, w, device, = *image.shape, image.device
@@ -3219,7 +3261,7 @@ def forward(
            image = vae.encode(image)
            lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-       losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+       losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)

        if not return_lowres_cond_image:
            return losses
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.9'
+__version__ = '1.11.1'
