-
Notifications
You must be signed in to change notification settings - Fork 243
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from sberbank-ai/feature/dwt_vae
Feature/dwt vae
- Loading branch information
Showing
10 changed files
with
533 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[run] | ||
omit = | ||
# omit this single file | ||
rudalle/vae/pytorch_wavelets_utils.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,4 +22,4 @@ | |
'image_prompts', | ||
] | ||
|
||
__version__ = '0.0.1-rc6' | ||
__version__ = '0.0.1-rc7' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# -*- coding: utf-8 -*- | ||
import pywt | ||
import torch | ||
import torch.nn as nn | ||
from taming.modules.diffusionmodules.model import Decoder | ||
|
||
from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int | ||
|
||
|
||
class DecoderDWT(nn.Module):
    """Decoder that predicts wavelet coefficients and inverts them to an image.

    The wrapped ``Decoder`` is always built with 12 output channels: the first
    3 are used as the lowpass band and the remaining 9 are viewed as 3 groups
    of 3 highpass bands (see ``dwt_to_img``).

    Args:
        ddconfig: Mapping of keyword arguments for ``Decoder``. It must contain
            ``'z_channels'``. Any ``'out_ch'`` entry is overridden with 12 on a
            copy; the caller's object is left untouched.
        embed_dim: Number of input channels of ``post_quant_conv``.
    """

    # Channel count the decoder must emit: 3 lowpass + 3 * 3 highpass.
    DWT_CHANNELS = 12

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        decoder_kwargs = self._decoder_kwargs(ddconfig)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder_kwargs['z_channels'], 1)
        self.decoder = Decoder(**decoder_kwargs)
        self.idwt = DWTInverse(mode='zero', wave='db1')

    @staticmethod
    def _decoder_kwargs(ddconfig):
        """Return a shallow dict copy of ``ddconfig`` with ``out_ch`` forced to 12."""
        kwargs = dict(ddconfig)
        kwargs['out_ch'] = DecoderDWT.DWT_CHANNELS
        return kwargs

    def forward(self, x):
        """Decode ``x`` into wavelet coefficients and reconstruct the image."""
        # NOTE(review): post_quant_conv is deliberately NOT applied here (the
        # call was disabled in the original code); presumably the caller
        # applies it before invoking this module -- confirm.
        freq = self.decoder(x)
        return self.dwt_to_img(freq)

    def dwt_to_img(self, img):
        """Split a 12-channel tensor into DWT bands and run the inverse DWT.

        Args:
            img: Tensor of shape (b, 12, h, w).

        Returns:
            The result of ``self.idwt`` applied to ``(low, [high])`` where
            ``low`` is ``img[:, :3]`` and ``high`` is ``img[:, 3:]`` viewed as
            (b, 3, 3, h, w).
        """
        b, _, h, w = img.size()
        low = img[:, :3, :, :]
        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
        return self.idwt((low, [high]))
|
||
|
||
class DWTInverse(nn.Module):
    """Perform a 2D inverse DWT reconstruction of an image.

    Args:
        wave (str, pywt.Wavelet or sequence): Wavelet to use. A string is
            looked up with ``pywt.Wavelet``; a ``pywt.Wavelet`` supplies its
            ``rec_lo``/``rec_hi`` filters for both columns and rows. Otherwise
            it must be a sequence of 2 filters (g0, g1, shared by columns and
            rows) or 4 filters (g0_col, g1_col, g0_row, g1_row).
        mode (str): Padding mode name; converted with ``mode_to_int`` on every
            forward pass.
        trace_model (bool): If True, ``_SFB2D`` is called directly instead of
            ``SFB2D.apply`` (presumably to make the module traceable --
            confirm against ``pytorch_wavelets_utils``).

    Raises:
        ValueError: If ``wave`` is a sequence whose length is neither 2 nor 4.
    """

    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on g0_col; fail early with a clear message.
            raise ValueError(
                'wave must be a wavelet name, a pywt.Wavelet, or a sequence '
                'of 2 or 4 filters, got a sequence of length {}'.format(len(wave))
            )
        # Prepare the filters and store them as buffers so they follow the
        # module across devices/dtypes.
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """Reconstruct the input from its DWT pyramid.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
                yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
                W_{in}')` and yh is a list of bandpass tensors of shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should
                match the format returned by DWTForward.

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)
        # Pick the synthesis routine once; it does not change between levels.
        synthesize = _SFB2D if self.trace_model else SFB2D.apply

        # Do a multilevel inverse transform, coarsest level first.
        for bands in yh[::-1]:
            if bands is None:
                # Match ll's dtype as well as its device so a missing level
                # also works when the module runs in a non-default precision.
                bands = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                    ll.shape[-1], device=ll.device,
                                    dtype=ll.dtype)

            # 'Unpad' added dimensions
            if ll.shape[-2] > bands.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > bands.shape[-1]:
                ll = ll[..., :-1]
            ll = synthesize(ll, bands, self.g0_col, self.g1_col,
                            self.g0_row, self.g1_row, mode)
        return ll
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.