diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..cc791c1 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[run] +omit = + # omit this single file + rudalle/vae/pytorch_wavelets_utils.py diff --git a/README.md b/README.md index dd7d907..f383ddd 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master) ``` -pip install rudalle==0.0.1rc6 +pip install rudalle==0.0.1rc7 ``` ### 🤗 HF Models: [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) @@ -92,6 +92,7 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky] ### 🚀 Contributors 🚀 +- [@bes](https://github.com/bes-dev) shared [great idea and realization with IDWT](https://github.com/bes-dev/vqvae_dwt_distiller.pytorch) for decoding images with higher quality 512x512! 😈💪 - [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference - [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab) - [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt diff --git a/requirements.txt b/requirements.txt index 43ab492..f734f7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ transformers~=4.10.2 youtokentome~=1.0.6 omegaconf>=2.0.0 einops~=0.3.2 +PyWavelets==1.1.1 torch torchvision matplotlib diff --git a/rudalle/__init__.py b/rudalle/__init__.py index 0be85df..47af0e6 100644 --- a/rudalle/__init__.py +++ b/rudalle/__init__.py @@ -22,4 +22,4 @@ 'image_prompts', ] -__version__ = '0.0.1-rc6' +__version__ = '0.0.1-rc7' diff --git a/rudalle/vae/__init__.py b/rudalle/vae/__init__.py index fdda089..994997d 100644 --- a/rudalle/vae/__init__.py +++ b/rudalle/vae/__init__.py @@ -8,17 +8,23 @@ from .model import VQGanGumbelVAE -def get_vae(pretrained=True, cache_dir='/tmp/rudalle'): +def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'): # TODO config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')) - vae = VQGanGumbelVAE(config) + vae = VQGanGumbelVAE(config, dwt=dwt) if pretrained: repo_id = 'shonenkov/rudalle-utils' - filename = 'vqgan.gumbelf8-sber.model.ckpt' + if dwt: + filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt' + else: + filename = 'vqgan.gumbelf8-sber.model.ckpt' cache_dir = join(cache_dir, 'vae') config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) checkpoint = torch.load(join(cache_dir, filename), map_location='cpu') - vae.model.load_state_dict(checkpoint['state_dict'], strict=False) + if dwt: + vae.load_state_dict(checkpoint['state_dict']) + else: + vae.model.load_state_dict(checkpoint['state_dict'], strict=False) print('vae --> ready') return vae diff --git a/rudalle/vae/decoder_dwt.py b/rudalle/vae/decoder_dwt.py new file mode 100644 index 0000000..7b97ce7 --- /dev/null +++ b/rudalle/vae/decoder_dwt.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +import pywt +import torch +import torch.nn as nn +from taming.modules.diffusionmodules.model import Decoder + +from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int + + +class DecoderDWT(nn.Module): + def __init__(self, ddconfig, embed_dim): + super().__init__() + if ddconfig.out_ch != 12: + ddconfig.out_ch = 12 + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) + self.decoder = Decoder(**ddconfig) + self.idwt = DWTInverse(mode='zero', wave='db1') + + def forward(self, x): + # x = self.post_quant_conv(x) + freq = self.decoder(x) + img = self.dwt_to_img(freq) + return img + + def dwt_to_img(self, img): + b, c, h, w = img.size() + low = img[:, :3, :, :] + high = img[:, 3:, :, :].view(b, 3, 3, h, w) + return self.idwt((low, [high])) + + +class DWTInverse(nn.Module): + """ Performs a 2d DWT Inverse reconstruction of an image + + Args: + wave (str or pywt.Wavelet): Which wavelet to use + C: deprecated, will be removed in future + """ + + def __init__(self, wave='db1', mode='zero', trace_model=False): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + g0_col, g1_col = wave.rec_lo, wave.rec_hi + g0_row, g1_row = g0_col, g1_col + else: + if len(wave) == 2: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = g0_col, g1_col + elif len(wave) == 4: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = wave[2], wave[3] + # Prepare the filters + filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) + self.register_buffer('g0_col', filts[0]) + self.register_buffer('g1_col', filts[1]) + self.register_buffer('g0_row', filts[2]) + self.register_buffer('g1_row', filts[3]) + self.mode = mode + self.trace_model = trace_model + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', + W_{in}')` and yh is a list of bandpass tensors of shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match + the format returned by DWTForward + + Returns: + Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. + + Note: + Can have None for any of the highpass scales and will treat the + values as zeros (not in an efficient way though). + """ + yl, yh = coeffs + ll = yl + mode = mode_to_int(self.mode) + + # Do a multilevel inverse transform + for h in yh[::-1]: + if h is None: + h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2], + ll.shape[-1], device=ll.device) + + # 'Unpad' added dimensions + if ll.shape[-2] > h.shape[-2]: + ll = ll[..., :-1, :] + if ll.shape[-1] > h.shape[-1]: + ll = ll[..., :-1] + if not self.trace_model: + ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode) + else: + ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode) + return ll diff --git a/rudalle/vae/model.py b/rudalle/vae/model.py index 2c850b4..d1a76ae 100644 --- a/rudalle/vae/model.py +++ b/rudalle/vae/model.py @@ -8,16 +8,19 @@ from einops import rearrange from taming.modules.diffusionmodules.model import Encoder, Decoder +from .decoder_dwt import DecoderDWT + class VQGanGumbelVAE(torch.nn.Module): - def __init__(self, config): + def __init__(self, config, dwt=False): super().__init__() model = GumbelVQ( ddconfig=config.model.params.ddconfig, n_embed=config.model.params.n_embed, embed_dim=config.model.params.embed_dim, kl_weight=config.model.params.kl_weight, + dwt=dwt, ) self.model = model self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2)) @@ -79,11 +82,12 @@ def forward(self, z, temp=None, return_logits=False): class GumbelVQ(nn.Module): - def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8): + def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8): super().__init__() z_channels = ddconfig['z_channels'] + self.dwt = dwt self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) + self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig) self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0) self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1) @@ -95,6 +99,9 @@ def encode(self, x): return quant, emb_loss, info def decode(self, quant): - quant = self.post_quant_conv(quant) + if self.dwt: + quant = self.decoder.post_quant_conv(quant) + else: + quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec diff --git a/rudalle/vae/pytorch_wavelets_utils.py b/rudalle/vae/pytorch_wavelets_utils.py new file mode 100644 index 0000000..b39f6d3 --- /dev/null +++ b/rudalle/vae/pytorch_wavelets_utils.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- +""" +Useful utilities for testing the 2-D DTCWT with synthetic images +License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE +Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa +""" + +import pywt +import torch +import numpy as np +import torch.nn.functional as F +from torch.autograd import Function + + +def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1): + """ 1D synthesis filter bank of an image tensor + """ + C = lo.shape[1] + d = dim % 4 + # If g0, g1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(g0, torch.Tensor): + g0 = torch.tensor(np.copy(np.array(g0).ravel()), + dtype=torch.float, device=lo.device) + if not isinstance(g1, torch.Tensor): + g1 = torch.tensor(np.copy(np.array(g1).ravel()), + dtype=torch.float, device=lo.device) + L = g0.numel() + shape = [1, 1, 1, 1] + shape[d] = L + N = 2*lo.shape[d] + # If g aren't in the right shape, make them so + if g0.shape != tuple(shape): + g0 = g0.reshape(*shape) + if g1.shape != tuple(shape): + g1 = g1.reshape(*shape) + + s = (2, 1) if d == 2 else (1, 2) + g0 = torch.cat([g0]*C, dim=0) + g1 = torch.cat([g1]*C, dim=0) + if mode == 'per' or mode == 'periodization': + y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \ + F.conv_transpose2d(hi, g1, stride=s, groups=C) + if d == 2: + y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2] + y = y[:, :, :N] + else: + y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2] + y = y[:, :, :, :N] + y = roll(y, 1-L//2, dim=dim) + else: + if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \ + mode == 'periodic': + pad = (L-2, 0) if d == 2 else (0, L-2) + y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \ + F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C) + else: + raise ValueError('Unkown pad type: {}'.format(mode)) + + return y + + +def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode): + mode = int_to_mode(mode) + + lh, hl, hh = torch.unbind(highs, dim=2) + lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2) + hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) + y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) + + return y + + +def roll(x, n, dim, make_even=False): + if n < 0: + n = x.shape[dim] + n + + if make_even and x.shape[dim] % 2 == 1: + end = 1 + else: + end = 0 + + if dim == 0: + return torch.cat((x[-n:], x[:-n+end]), dim=0) + elif dim == 1: + return torch.cat((x[:, -n:], x[:, :-n+end]), dim=1) + elif dim == 2 or dim == -2: + return torch.cat((x[:, :, -n:], x[:, :, :-n+end]), dim=2) + elif dim == 3 or dim == -1: + return torch.cat((x[:, :, :, -n:], x[:, :, :, :-n+end]), dim=3) + + +def int_to_mode(mode): + if mode == 0: + return 'zero' + elif mode == 1: + return 'symmetric' + elif mode == 2: + return 'periodization' + elif mode == 3: + return 'constant' + elif mode == 4: + return 'reflect' + elif mode == 5: + return 'replicate' + elif mode == 6: + return 'periodic' + else: + raise ValueError('Unkown pad type: {}'.format(mode)) + + +def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None): + """ + Prepares the filters to be of the right form for the sfb2d function. In + particular, makes the tensors the right shape. It does not mirror image them + as as sfb2d uses conv2d_transpose which acts like normal convolution. + Inputs: + g0_col (array-like): low pass column filter bank + g1_col (array-like): high pass column filter bank + g0_row (array-like): low pass row filter bank. If none, will assume the + same as column filter + g1_row (array-like): high pass row filter bank. If none, will assume the + same as column filter + device: which device to put the tensors on to + Returns: + (g0_col, g1_col, g0_row, g1_row) + """ + g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device) + if g0_row is None: + g0_row, g1_row = g0_col, g1_col + else: + g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device) + + g0_col = g0_col.reshape((1, 1, -1, 1)) + g1_col = g1_col.reshape((1, 1, -1, 1)) + g0_row = g0_row.reshape((1, 1, 1, -1)) + g1_row = g1_row.reshape((1, 1, 1, -1)) + + return g0_col, g1_col, g0_row, g1_row + + +def prep_filt_sfb1d(g0, g1, device=None): + """ + Prepares the filters to be of the right form for the sfb1d function. In + particular, makes the tensors the right shape. It does not mirror image them + as as sfb2d uses conv2d_transpose which acts like normal convolution. + Inputs: + g0 (array-like): low pass filter bank + g1 (array-like): high pass filter bank + device: which device to put the tensors on to + Returns: + (g0, g1) + """ + g0 = np.array(g0).ravel() + g1 = np.array(g1).ravel() + t = torch.get_default_dtype() + g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1)) + g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1)) + + return g0, g1 + + +def mode_to_int(mode): + if mode == 'zero': + return 0 + elif mode == 'symmetric': + return 1 + elif mode == 'per' or mode == 'periodization': + return 2 + elif mode == 'constant': + return 3 + elif mode == 'reflect': + return 4 + elif mode == 'replicate': + return 5 + elif mode == 'periodic': + return 6 + else: + raise ValueError('Unkown pad type: {}'.format(mode)) + + +def afb1d(x, h0, h1, mode='zero', dim=-1): + """ 1D analysis filter bank (along one dimension only) of an image + Inputs: + x (tensor): 4D input with the last two dimensions the spatial input + h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1, + h, 1) or (1, 1, 1, w) + mode (str): padding method + dim (int) - dimension of filtering. d=2 is for a vertical filter (called + column filtering but filters across the rows). d=3 is for a + horizontal filter, (called row filtering but filters across the + columns). + Returns: + lohi: lowpass and highpass subbands concatenated along the channel + dimension + """ + C = x.shape[1] + # Convert the dim to positive + d = dim % 4 + s = (2, 1) if d == 2 else (1, 2) + N = x.shape[d] + # If h0, h1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(h0, torch.Tensor): + h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]), + dtype=torch.float, device=x.device) + if not isinstance(h1, torch.Tensor): + h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]), + dtype=torch.float, device=x.device) + L = h0.numel() + L2 = L // 2 + shape = [1, 1, 1, 1] + shape[d] = L + # If h aren't in the right shape, make them so + if h0.shape != tuple(shape): + h0 = h0.reshape(*shape) + if h1.shape != tuple(shape): + h1 = h1.reshape(*shape) + h = torch.cat([h0, h1] * C, dim=0) + + if mode == 'per' or mode == 'periodization': + if x.shape[dim] % 2 == 1: + if d == 2: + x = torch.cat((x, x[:, :, -1:]), dim=2) + else: + x = torch.cat((x, x[:, :, :, -1:]), dim=3) + N += 1 + x = roll(x, -L2, dim=d) + pad = (L-1, 0) if d == 2 else (0, L-1) + lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + N2 = N//2 + if d == 2: + lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2] + lohi = lohi[:, :, :N2] + else: + lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2] + lohi = lohi[:, :, :, :N2] + else: + # Calculate the pad size + outsize = pywt.dwt_coeff_len(N, L, mode=mode) + p = 2 * (outsize - 1) - N + L + if mode == 'zero': + # Sadly, pytorch only allows for same padding before and after, if + # we need to do more padding after for odd length signals, have to + # prepad + if p % 2 == 1: + pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0) + x = F.pad(x, pad) + pad = (p//2, 0) if d == 2 else (0, p//2) + # Calculate the high and lowpass + lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic': + pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0) + x = mypad(x, pad=pad, mode=mode) + lohi = F.conv2d(x, h, stride=s, groups=C) + else: + raise ValueError('Unkown pad type: {}'.format(mode)) + + return lohi + + +def mypad(x, pad, mode='constant', value=0): + """ Function to do numpy like padding on tensors. Only works for 2-D + padding. + Inputs: + x (tensor): tensor to pad + pad (tuple): tuple of (left, right, top, bottom) pad sizes + mode (str): 'symmetric', 'wrap', 'constant, 'reflect', 'replicate', or + 'zero'. The padding technique. + """ + if mode == 'symmetric': + # Vertical only + if pad[0] == 0 and pad[1] == 0: + m1, m2 = pad[2], pad[3] + l = x.shape[-2] # noqa + xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) + return x[:, :, xe] + # horizontal only + elif pad[2] == 0 and pad[3] == 0: + m1, m2 = pad[0], pad[1] + l = x.shape[-1] # noqa + xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5) + return x[:, :, :, xe] + # Both + else: + m1, m2 = pad[0], pad[1] + l1 = x.shape[-1] + xe_row = reflect(np.arange(-m1, l1+m2, dtype='int32'), -0.5, l1-0.5) + m1, m2 = pad[2], pad[3] + l2 = x.shape[-2] + xe_col = reflect(np.arange(-m1, l2+m2, dtype='int32'), -0.5, l2-0.5) + i = np.outer(xe_col, np.ones(xe_row.shape[0])) + j = np.outer(np.ones(xe_col.shape[0]), xe_row) + return x[:, :, i, j] + elif mode == 'periodic': + # Vertical only + if pad[0] == 0 and pad[1] == 0: + xe = np.arange(x.shape[-2]) + xe = np.pad(xe, (pad[2], pad[3]), mode='wrap') + return x[:, :, xe] + # Horizontal only + elif pad[2] == 0 and pad[3] == 0: + xe = np.arange(x.shape[-1]) + xe = np.pad(xe, (pad[0], pad[1]), mode='wrap') + return x[:, :, :, xe] + # Both + else: + xe_col = np.arange(x.shape[-2]) + xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap') + xe_row = np.arange(x.shape[-1]) + xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap') + i = np.outer(xe_col, np.ones(xe_row.shape[0])) + j = np.outer(np.ones(xe_col.shape[0]), xe_row) + return x[:, :, i, j] + + elif mode == 'constant' or mode == 'reflect' or mode == 'replicate': + return F.pad(x, pad, mode, value) + elif mode == 'zero': + return F.pad(x, pad) + else: + raise ValueError('Unkown pad type: {}'.format(mode)) + + +def reflect(x, minx, maxx): + """Reflect the values in matrix *x* about the scalar values *minx* and + *maxx*. Hence a vector *x* containing a long linearly increasing series is + converted into a waveform which ramps linearly up and down between *minx* + and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + + 0.5), the ramps will have repeated max and min samples. + .. codeauthor:: Rich Wareham , Aug 2013 + .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. + """ + x = np.asanyarray(x) + rng = maxx - minx + rng_by_2 = 2 * rng + mod = np.fmod(x - minx, rng_by_2) + normed_mod = np.where(mod < 0, mod + rng_by_2, mod) + out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx + return np.array(out, dtype=x.dtype) + + +class SFB2D(Function): + """ Does a single level 2d wavelet decomposition of an input. Does separate + row and column filtering by two calls to + :py:func:`pytorch_wavelets.dwt.lowlevel.afb1d` + Needs to have the tensors in the right form. Because this function defines + its own backward pass, saves on memory by not having to save the input + tensors. + Inputs: + x (torch.Tensor): Input to decompose + h0_row: row lowpass + h1_row: row highpass + h0_col: col lowpass + h1_col: col highpass + mode (int): use mode_to_int to get the int code here + We encode the mode as an integer rather than a string as gradcheck causes an + error when a string is provided. + Returns: + y: Tensor of shape (N, C*4, H, W) + """ + @staticmethod + def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode): + mode = int_to_mode(mode) + ctx.mode = mode + ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col) + + lh, hl, hh = torch.unbind(highs, dim=2) + lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2) + hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2) + y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3) + return y + + @staticmethod + def backward(ctx, dy): + dlow, dhigh = None, None + if ctx.needs_input_grad[0]: + mode = ctx.mode + g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors + dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3) + dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2) + s = dx.shape + dx = dx.reshape(s[0], -1, 4, s[-2], s[-1]) + dlow = dx[:, :, 0].contiguous() + dhigh = dx[:, :, 1:].contiguous() + return dlow, dhigh, None, None, None, None, None diff --git a/tests/conftest.py b/tests/conftest.py index 96d7a0f..65abdc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,12 @@ def vae(): yield vae +@pytest.fixture(scope='module') +def dwt_vae(): + vae = get_vae(pretrained=False, dwt=True) + yield vae + + @pytest.fixture(scope='module') def yttm_tokenizer(): tokenizer = get_tokenizer() diff --git a/tests/test_vae.py b/tests/test_vae.py index 0955573..43639e8 100644 --- a/tests/test_vae.py +++ b/tests/test_vae.py @@ -25,6 +25,15 @@ def test_reconstruct_vae(vae, sample_image, target_image_size): assert output.shape == (1, 3, target_image_size, target_image_size) +@pytest.mark.parametrize('target_image_size', [256]) +def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size): + img = sample_image.copy() + with torch.no_grad(): + x_vqgan = preprocess(img, target_image_size=target_image_size) + output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), dwt_vae.model) + assert output.shape == (1, 3, target_image_size*2, target_image_size*2) + + def preprocess(img, target_image_size=256): s = min(img.size) if s < target_image_size: