ddim makes the generation worse #231

Open

YUHANG-Ma opened this issue Aug 18, 2022 · 3 comments

@YUHANG-Ma
Hi, I ran into an issue: when I use DDIM for the decoder sampling, the generated pictures don't look good.

[image: DDIM samples at 250 steps]

When I increase the number of sampling steps to 1000, I get the following result instead.

[image: samples at 1000 steps]

Could I ask how to fix this?

The following is the DDIM part of my code.

```python
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
    device = self.betas.device

    b = shape[0]
    img = torch.randn(shape, device = device)
    timesteps = 250
    times = torch.linspace(0., 1000, steps = timesteps + 2)[:-1]

    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:]))
    print(time_pairs)
    alphas = self.alphas_cumprod_prev
    if not is_latent_diffusion:
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)

    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        alpha = alphas[time]
        alpha_next = alphas[time_next]

        time_cond = torch.full((b,), time, device = device, dtype = torch.long)

        pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

        if learned_variance:
            pred, _ = pred.chunk(2, dim = 1)

        if predict_x_start:
            x_start = pred
            pred_noise = self.predict_noise_from_start(img, t = time_cond, x0 = pred)
        else:
            x_start = self.predict_start_from_noise(img, t = time_cond, noise = pred)
            pred_noise = pred

        if clip_denoised:
            s = 1.
            # clip by threshold, depending on whether static or dynamic
            x_start = x_start.clamp(-s, s) / s

        c1 = 1 * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
        noise = torch.randn_like(img) if time_next > 0 else 0.

        img = x_start * alpha_next.sqrt() + \
              c1 * noise + \
              c2 * pred_noise

    img = self.unnormalize_img(img)
    return img
```
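
For comparison, here is a minimal stand-alone sketch of the standard DDIM update (Song et al.), assuming `alpha` and `alpha_next` are the cumulative products (`alphas_cumprod`) at the current and next timestep. The `eta` parameter scales the stochastic term: `eta = 0` gives the usual deterministic DDIM, while `eta = 1` gives DDPM-like stochastic sampling. In the snippet above, the corresponding factor in `c1` is effectively hard-coded to `1`, which tends to degrade quality when the number of sampling steps is reduced. The helper name `ddim_step` is hypothetical and not part of the repository's API.

```python
import torch

def ddim_step(x_t, x_start, pred_noise, alpha, alpha_next, eta = 0.):
    # One DDIM update from timestep t to the next (earlier) timestep.
    # alpha / alpha_next are cumulative alphas (alphas_cumprod), not per-step alphas.
    # eta = 0. -> deterministic DDIM; eta = 1. -> DDPM-like stochastic sampling.
    sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c = (1 - alpha_next - sigma ** 2).sqrt()
    noise = torch.randn_like(x_t)
    return x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
```

A quick way to test whether the stochastic term is the culprit would be to replace the leading `1 *` in `c1` with `0.` and re-run sampling at 250 steps.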
@FTKyaoyuan

Can you tell me which dataset you used? Thanks.

@YUHANG-Ma (Author)

> Can you tell me which dataset you used? Thanks.

Pictures collected from the Internet.

@FTKyaoyuan

> Pictures collected from the Internet.

Can I see the training code you used?
