
Even 80 GB is not sufficient for SANA 4K model VAE decoding - something feels wrong? #139

Open
FurkanGozukara opened this issue Jan 9, 2025 · 12 comments
Labels
Answered (Answered the question), working (working on this issue)

Comments

@FurkanGozukara

Tried to allocate 36.00 GiB. GPU 0 has a total capacity of 44.52 GiB of which 30.29 GiB is free

Testing on an L40S and it still fails

Using the official pipeline: https://github.com/NVlabs/Sana/blob/main/app/sana_pipeline.py

During inference it uses around 18 GB of VRAM, but the VAE decode causes the issue


SANA 2K and 1K work great

2025-01-09 21:48:36 - [Sana] - WARNING - use pe: True, position embed interpolation: 2.0, base size: 128
2025-01-09 21:48:36 - [Sana] - WARNING - attention type: linear; ffn type: glumbconv; autocast linear attn: False
2025-01-09 21:48:48 - [Sana] - INFO - use_fp32_attention: False
2025-01-09 21:48:48 - [Sana] - INFO - SanaMS:SanaMS_1600M_P1_D20,Model Parameters: 1,604,462,752
[Sana] Loading model from models/checkpoints/Sana_1600M_4Kpx_BF16.pth
2025-01-09 21:48:52 - [Sana] - INFO - Generating sample from ckpt: models/checkpoints/Sana_1600M_4Kpx_BF16.pth
2025-01-09 21:48:52 - [Sana] - WARNING - Missing keys: ['pos_embed']
2025-01-09 21:48:52 - [Sana] - WARNING - Unexpected keys: []
PORT: 15432, model_path: models/checkpoints/Sana_1600M_4Kpx_BF16.pth
super fast car
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:23<00:00,  1.41s/it]
Traceback (most recent call last):
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/blocks.py", line 2045, in process_api
    result = await self.call_function(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/blocks.py", line 1604, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 715, in async_iteration
    return await anext(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 709, in __anext__
    return await anyio.to_thread.run_sync(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2461, in run_sync_in_worker_thread
    return await future
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 962, in run
    result = context.run(func, *args)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 692, in run_sync_iterator_async
    return next(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 853, in gen_wrapper
    response = next(iterator)
  File "/workspace/Sana/app/secourses_app.py", line 346, in generate_multiple
    images, current_seed, speed_info = generate(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/app/secourses_app.py", line 276, in generate
    images = pipe(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/app/sana_pipeline.py", line 297, in forward
    sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
  File "/workspace/Sana/diffusion/model/builder.py", line 119, in vae_decode
    samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py", line 445, in decode
    x = self.decoder(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py", line 414, in forward
    x = stage(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 834, in forward
    x = op(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 780, in forward
    res = self.forward_main(x) + self.shortcut(x)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 770, in forward_main
    return self.main(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 218, in forward
    x = self.conv(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 89, in forward
    x = self.conv(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 36.00 GiB. GPU 0 has a total capacity of 44.52 GiB of which 27.03 GiB is free. Process 2674842 has 17.49 GiB memory in use. Of the allocated memory 15.63 GiB is allocated by PyTorch, and 1.34 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
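For context: the failed 36.00 GiB allocation is consistent with a single fp32 activation tensor at 4K resolution inside the DC-AE decoder. A rough back-of-envelope calculator (the 576-channel width and fp32 element size are assumptions picked to match the reported number; the decoder's internal widths are not shown in the log):

```python
def conv_activation_gib(height, width, channels, bytes_per_el=4, batch=1):
    """Memory (GiB) of one NCHW activation tensor produced by a conv layer."""
    return batch * channels * height * width * bytes_per_el / 2**30

# Hypothetical combination that reproduces the reported 36.00 GiB:
# 576 channels at 4096x4096 in fp32.
print(conv_activation_gib(4096, 4096, 576))  # 36.0
```

Whatever the exact widths, the point is that at 4096x4096 a single decoder activation can dwarf the 18 GB the diffusion model itself uses.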

@FurkanGozukara
Author

even an A100 80 GB fails

@lawrence-cj

Starting VAE decode...
Traceback (most recent call last):
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/blocks.py", line 2045, in process_api
    result = await self.call_function(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/blocks.py", line 1604, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 715, in async_iteration
    return await anext(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 709, in __anext__
    return await anyio.to_thread.run_sync(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2461, in run_sync_in_worker_thread
    return await future
  File "/workspace/Sana/venv/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 962, in run
    result = context.run(func, *args)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 692, in run_sync_iterator_async
    return next(iterator)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/gradio/utils.py", line 853, in gen_wrapper
    response = next(iterator)
  File "/workspace/Sana/app/secourses_app.py", line 346, in generate_multiple
    images, current_seed, speed_info = generate(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/app/secourses_app.py", line 276, in generate
    images = pipe(
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/Sana/app/sana_pipeline.py", line 327, in forward
    sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
  File "/workspace/Sana/diffusion/model/builder.py", line 119, in vae_decode
    samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py", line 445, in decode
    x = self.decoder(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/efficientvit/dc_ae.py", line 414, in forward
    x = stage(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 834, in forward
    x = op(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 780, in forward
    res = self.forward_main(x) + self.shortcut(x)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 770, in forward_main
    return self.main(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 218, in forward
    x = self.conv(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/diffusion/model/dc_ae/efficientvit/models/nn/ops.py", line 89, in forward
    x = self.conv(x)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/workspace/Sana/venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 72.00 GiB. GPU 0 has a total capacity of 79.25 GiB of which 33.02 GiB is free. Process 183813 has 46.23 GiB memory in use. Of the allocated memory 19.54 GiB is allocated by PyTorch, and 26.18 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
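The error text itself points at one mitigation worth trying before any code changes: 26.18 GiB is reserved-but-unallocated, so fragmentation is part of the picture here. Setting the allocator flag is harmless to try, though it cannot help when the decoder genuinely requests a single 72 GiB block:

```shell
# Reduce CUDA allocator fragmentation, as suggested by the OOM message itself.
# Note: this only mitigates fragmentation; it cannot satisfy a single 72 GiB
# allocation, so a tiled or patched decode is still needed for the 4K model.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
echo "$PYTORCH_CUDA_ALLOC_CONF"
```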

@FurkanGozukara FurkanGozukara changed the title Even 48 GB is not sufficient for SANA 4K model VAE decoding - something feels wrong? Even 80 GB is not sufficient for SANA 4K model VAE decoding - something feels wrong? Jan 9, 2025
@geronimi73

try this, it runs on an A40
relevant PR: huggingface/diffusers#10510

@FurkanGozukara
Author

FurkanGozukara commented Jan 9, 2025

@geronimi73 thanks, but it isn't even merged yet, so I can't use it — non-technical people install via my installers

@lawrence-cj please consider fixing the pipeline here

@xieenze
Collaborator

xieenze commented Jan 9, 2025

@FurkanGozukara try to use patch_conv on the vae decoder :) https://github.com/mit-han-lab/patch_conv

@FurkanGozukara
Author

@FurkanGozukara try to use patch_conv on the vae decoder :) https://github.com/mit-han-lab/patch_conv

interesting, gonna test it now

so it is model = convert_model(model, splits=4) # The only modification you need to make

how many splits should I make? why 4?
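On the "why 4?" question: patch_conv splits the spatial dimension into chunks and runs the convolution chunk by chunk, so peak activation memory drops roughly by the split factor while the output stays identical (each chunk carries a halo of neighboring rows so border outputs match). More splits means less peak memory but more overhead. A 1-D toy sketch of the principle — not patch_conv's actual code:

```python
def conv1d(x, k):
    """'Same' 1-D convolution with zero padding, kernel length 3."""
    xp = [0] + x + [0]
    return [sum(xp[i + j] * k[j] for j in range(3)) for i in range(len(x))]

def conv1d_split(x, k, splits=4):
    """Same result, computed chunk by chunk with a 1-element halo per side."""
    xp = [0] + x + [0]
    n = len(x)
    step = -(-n // splits)  # ceil division: outputs per chunk
    out = []
    for s in range(0, n, step):
        e = min(s + step, n)
        chunk = xp[s:e + 2]  # this chunk's inputs, halo included
        out += [sum(chunk[i + j] * k[j] for j in range(3)) for i in range(e - s)]
    return out

x, k = [1, 2, 3, 4, 5, 6], [1, 1, 1]
assert conv1d(x, k) == conv1d_split(x, k, splits=2) == [3, 6, 9, 12, 15, 11]
```

Since the result is mathematically unchanged, the split count is purely a memory/overhead trade-off — which is also why a too-small split count can still OOM.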

@FurkanGozukara
Copy link
Author

FurkanGozukara commented Jan 9, 2025

@xieenze I tried it and it's not working

no errors, but zero difference

mit-han-lab/patch_conv#4


def load_model(config_path, model_path):
    global pipe
    if pipe is None:
        if torch.cuda.is_available():
            try:
                pipe = SanaPipeline(config_path)
                pipe.from_pretrained(model_path)
                pipe.register_progress_bar(gr.Progress())
                # patch_conv applied to the whole pipeline
                pipe = convert_model(pipe, splits=4)
                return True, "Model loaded successfully"
            except Exception as e:
                return False, f"Error loading model: {str(e)}"
        return False, "CUDA is not available"
    return True, "Model already loaded"

@xieenze
Copy link
Collaborator

xieenze commented Jan 9, 2025

@FurkanGozukara try converting only the VAE, such as below:


@FurkanGozukara
Copy link
Author

@FurkanGozukara try only convert vae? such as below


thanks

I tested it here and it didn't work, but gonna test yours now:

def get_vae(name, model_path, device="cuda"):
    if name == "sdxl" or name == "sd3":
        vae = AutoencoderKL.from_pretrained(model_path).to(device).to(torch.float16)
        if name == "sdxl":
            vae.config.shift_factor = 0
        return vae
    elif "dc-ae" in name:
        print(colored(f"[DC-AE] Loading model from {model_path}", attrs=["bold"]))
        dc_ae = DCAE_HF.from_pretrained(model_path).to(device).eval()
        dc_ae = convert_model(dc_ae, splits=8)
        print("model converted build_sana_model")
        return dc_ae
    else:
        print("error load vae")
        exit()

@FurkanGozukara
Copy link
Author

@xieenze not working either

tried splits of 4 / 8 / 16 / 32 / 64 / 128 :(

I have an RTX 3090 (24 GB)

def vae_decode(name, vae, latent):
    if name == "sdxl" or name == "sd3":
        latent = (latent.detach() / vae.config.scaling_factor) + vae.config.shift_factor
        vae = convert_model(vae, splits=128)
        print("model converted build_sana_model")
        samples = vae.decode(latent).sample
    elif "dc-ae" in name:
        ae = vae
        ae = convert_model(ae, splits=128)
        print("model converted build_sana_model")
        samples = ae.decode(latent.detach() / ae.cfg.scaling_factor)
    else:
        print("error load vae")
        exit()
    return samples

@elismasilva

@xieenze I think it will be necessary to implement a tiled-VAE module for this. SUPIR implements one, but I don't know if the architecture is similar.
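A tiled-VAE decoder of the kind described here decodes overlapping tiles and blends the seams, so peak memory scales with the tile size rather than the full 4K image. The core of it is just computing overlapping spans that cover each axis; a minimal sketch (the 1024-px tile and 256-px overlap are illustrative numbers, not what diffusers or SUPIR actually use):

```python
def tile_spans(size, tile, overlap):
    """Overlapping (start, end) spans of length <= tile covering [0, size)."""
    stride = tile - overlap
    spans, s = [], 0
    while True:
        e = min(s + tile, size)
        spans.append((s, e))
        if e == size:
            return spans
        s += stride

# Covering a 4096-px axis with 1024-px tiles and 256-px overlap takes 5 tiles:
print(tile_spans(4096, 1024, 256))
# [(0, 1024), (768, 1792), (1536, 2560), (2304, 3328), (3072, 4096)]
```

A real implementation would run the decoder on each (row span x column span) tile and feather the overlapping regions together; that blending is what the diffusers PR referenced below contributes.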

@lawrence-cj
Collaborator

huggingface/diffusers#10510

We are working on it. Will fix when this PR is merged.

@lawrence-cj lawrence-cj added Answered Answered the question working working on this issue labels Jan 10, 2025
@FurkanGozukara
Author

huggingface/diffusers#10510

We are working on it. Will fix when this PR is merged.

so you are not going to update the SANA pipeline here? we have to move to diffusers, I guess?

@lawrence-cj can you add a diffusers demo with all features here? https://github.com/NVlabs/Sana/tree/main/app
