Skip to content

Commit

Permalink
Merge pull request #43 from sberbank-ai/feature/dwt_vae
Browse files Browse the repository at this point in the history
Feature/dwt vae
  • Loading branch information
shonenkov authored Nov 9, 2021
2 parents a23a834 + a1980cf commit 5488289
Show file tree
Hide file tree
Showing 10 changed files with 533 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[run]
omit =
# omit this single file
rudalle/vae/pytorch_wavelets_utils.py
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

```
pip install rudalle==0.0.1rc6
pip install rudalle==0.0.1rc7
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
Expand Down Expand Up @@ -92,6 +92,7 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]

### 🚀 Contributors 🚀

- [@bes](https://github.com/bes-dev) shared [great idea and realization with IDWT](https://github.com/bes-dev/vqvae_dwt_distiller.pytorch) for decoding images with higher quality 512x512! 😈💪
- [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
- [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
- [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ transformers~=4.10.2
youtokentome~=1.0.6
omegaconf>=2.0.0
einops~=0.3.2
PyWavelets==1.1.1
torch
torchvision
matplotlib
2 changes: 1 addition & 1 deletion rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
'image_prompts',
]

__version__ = '0.0.1-rc6'
__version__ = '0.0.1-rc7'
14 changes: 10 additions & 4 deletions rudalle/vae/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
from .model import VQGanGumbelVAE


def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
# TODO
config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml'))
vae = VQGanGumbelVAE(config)
vae = VQGanGumbelVAE(config, dwt=dwt)
if pretrained:
repo_id = 'shonenkov/rudalle-utils'
filename = 'vqgan.gumbelf8-sber.model.ckpt'
if dwt:
filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt'
else:
filename = 'vqgan.gumbelf8-sber.model.ckpt'
cache_dir = join(cache_dir, 'vae')
config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
if dwt:
vae.load_state_dict(checkpoint['state_dict'])
else:
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
print('vae --> ready')
return vae
102 changes: 102 additions & 0 deletions rudalle/vae/decoder_dwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
import pywt
import torch
import torch.nn as nn
from taming.modules.diffusionmodules.model import Decoder

from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int


class DecoderDWT(nn.Module):
def __init__(self, ddconfig, embed_dim):
super().__init__()
if ddconfig.out_ch != 12:
ddconfig.out_ch = 12
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
self.decoder = Decoder(**ddconfig)
self.idwt = DWTInverse(mode='zero', wave='db1')

def forward(self, x):
# x = self.post_quant_conv(x)
freq = self.decoder(x)
img = self.dwt_to_img(freq)
return img

def dwt_to_img(self, img):
b, c, h, w = img.size()
low = img[:, :3, :, :]
high = img[:, 3:, :, :].view(b, 3, 3, h, w)
return self.idwt((low, [high]))


class DWTInverse(nn.Module):
""" Performs a 2d DWT Inverse reconstruction of an image
Args:
wave (str or pywt.Wavelet): Which wavelet to use
C: deprecated, will be removed in future
"""

def __init__(self, wave='db1', mode='zero', trace_model=False):
super().__init__()
if isinstance(wave, str):
wave = pywt.Wavelet(wave)
if isinstance(wave, pywt.Wavelet):
g0_col, g1_col = wave.rec_lo, wave.rec_hi
g0_row, g1_row = g0_col, g1_col
else:
if len(wave) == 2:
g0_col, g1_col = wave[0], wave[1]
g0_row, g1_row = g0_col, g1_col
elif len(wave) == 4:
g0_col, g1_col = wave[0], wave[1]
g0_row, g1_row = wave[2], wave[3]
# Prepare the filters
filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
self.register_buffer('g0_col', filts[0])
self.register_buffer('g1_col', filts[1])
self.register_buffer('g0_row', filts[2])
self.register_buffer('g1_row', filts[3])
self.mode = mode
self.trace_model = trace_model

def forward(self, coeffs):
"""
Args:
coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
W_{in}')` and yh is a list of bandpass tensors of shape
:math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
the format returned by DWTForward
Returns:
Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
Note:
:math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
downsampled shapes of the DWT pyramid.
Note:
Can have None for any of the highpass scales and will treat the
values as zeros (not in an efficient way though).
"""
yl, yh = coeffs
ll = yl
mode = mode_to_int(self.mode)

# Do a multilevel inverse transform
for h in yh[::-1]:
if h is None:
h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
ll.shape[-1], device=ll.device)

# 'Unpad' added dimensions
if ll.shape[-2] > h.shape[-2]:
ll = ll[..., :-1, :]
if ll.shape[-1] > h.shape[-1]:
ll = ll[..., :-1]
if not self.trace_model:
ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
else:
ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
return ll
15 changes: 11 additions & 4 deletions rudalle/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
from einops import rearrange
from taming.modules.diffusionmodules.model import Encoder, Decoder

from .decoder_dwt import DecoderDWT


class VQGanGumbelVAE(torch.nn.Module):

def __init__(self, config):
def __init__(self, config, dwt=False):
super().__init__()
model = GumbelVQ(
ddconfig=config.model.params.ddconfig,
n_embed=config.model.params.n_embed,
embed_dim=config.model.params.embed_dim,
kl_weight=config.model.params.kl_weight,
dwt=dwt,
)
self.model = model
self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
Expand Down Expand Up @@ -79,11 +82,12 @@ def forward(self, z, temp=None, return_logits=False):

class GumbelVQ(nn.Module):

def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8):
super().__init__()
z_channels = ddconfig['z_channels']
self.dwt = dwt
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig)
self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
Expand All @@ -95,6 +99,9 @@ def encode(self, x):
return quant, emb_loss, info

def decode(self, quant):
quant = self.post_quant_conv(quant)
if self.dwt:
quant = self.decoder.post_quant_conv(quant)
else:
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
Loading

0 comments on commit 5488289

Please sign in to comment.