diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 286c0b1b..94f129f1 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -38,6 +38,8 @@
 
 NAT = 1. / math.log(2.)
 
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions
 
 def exists(val):
@@ -2584,6 +2586,14 @@ def get_unet(self, unet_number):
         index = unet_number - 1
         return self.unets[index]
 
+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
@@ -2625,10 +2635,9 @@ def dynamic_threshold(self, x):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
         if predict_x_start:
             x_start = pred
@@ -2811,10 +2820,9 @@ def p_sample_loop_ddim(
 
             self_cond = x_start if unet.self_cond else None
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
             if predict_x_start:
                 x_start = pred
@@ -2886,16 +2894,13 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
 
         if unet.self_cond and random.random() < 0.5:
             with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
-
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                 self_cond = self_cond.detach()
 
         # forward to get model prediction
 
-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
@@ -2904,10 +2909,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
         target = noise if not predict_x_start else x_start
 
@@ -2930,7 +2932,7 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
         # if learning the variance, also include the extra weight kl loss
 
         true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 4574cc89..d07785c5 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.6.3'
+__version__ = '1.6.4'
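Reviewer note: a minimal standalone sketch of the refactor's core idea. The duplicated "chunk the channels when variance is learned" logic is now centralized in parse_unet_output, which returns a UnetOutput namedtuple. The function body below mirrors the diff; the free-function form (rather than a Decoder method) and the tensor shapes are illustrative assumptions, not taken from the library.

    # sketch of the parse_unet_output refactor, assuming illustrative shapes
    from collections import namedtuple
    import torch

    UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

    def parse_unet_output(learned_variance, output):
        var_interp_frac_unnormalized = None

        if learned_variance:
            # when the unet learns the variance it predicts twice the channels:
            # first half is the noise / x_start prediction, second half is the
            # unnormalized interpolation fraction for the variance
            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)

        return UnetOutput(output, var_interp_frac_unnormalized)

    # usage: a fake unet output with 6 channels (3 prediction + 3 variance)
    pred, var_frac = parse_unet_output(True, torch.randn(2, 6, 64, 64))
    assert pred.shape == (2, 3, 64, 64) and var_frac.shape == (2, 3, 64, 64)

    # without learned variance the output passes through untouched
    pred, var_frac = parse_unet_output(False, torch.randn(2, 3, 64, 64))
    assert var_frac is None

Because every call site now unpacks the same namedtuple, the learned-variance branch lives in one place, and the p_losses KL term can forward the raw unet_output straight into p_mean_variance via model_output.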