Support Automatic Mixed Precision (AMP) #51

Open
Wants to merge 1 commit into base: master
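The change itself is mechanical: every custom `torch.autograd.Function` in the package gains `@custom_fwd` on its `forward` and `@custom_bwd` on its `backward`, which is the pattern PyTorch documents for making custom Functions behave correctly inside `torch.cuda.amp.autocast` regions. A minimal sketch of that pattern on a toy Function (not code from this PR):

```python
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd

class Scale(Function):
    """Toy stand-in for the wavelet Functions patched below."""

    @staticmethod
    @custom_fwd            # forward's internal ops follow the surrounding autocast state
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        return x * w

    @staticmethod
    @custom_bwd            # backward re-runs under the same autocast state as forward
    def backward(ctx, dy):
        w, = ctx.saved_tensors
        return dy * w, None
```

Since no `cast_inputs` argument is passed anywhere in this diff, the filter-bank ops simply inherit whatever dtype autocast selects; `custom_fwd(cast_inputs=torch.float32)` would instead pin the Functions to full precision.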
9 changes: 9 additions & 0 deletions pytorch_wavelets/dtcwt/transform_funcs.py
@@ -1,6 +1,7 @@
import torch
from torch import tensor
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from pytorch_wavelets.dtcwt.lowlevel import colfilter, rowfilter
from pytorch_wavelets.dtcwt.lowlevel import coldfilt, rowdfilt
from pytorch_wavelets.dtcwt.lowlevel import colifilt, rowifilt, q2c, c2q
@@ -343,6 +344,7 @@ def inv_j2plus_rot(ll, highr, highi, g0a, g1a, g0b, g1b, g2a, g2b,
class FWD_J1(Function):
""" Differentiable function doing 1 level forward DTCWT """
@staticmethod
@custom_fwd
def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode):
mode = int_to_mode(mode)
ctx.mode = mode
@@ -358,6 +360,7 @@ def forward(ctx, x, h0, h1, skip_hps, o_dim, ri_dim, mode):
return ll, highs

@staticmethod
@custom_bwd
def backward(ctx, dl, dh):
h0, h1 = ctx.saved_tensors
mode = ctx.mode
@@ -377,6 +380,7 @@ def backward(ctx, dl, dh):
class FWD_J2PLUS(Function):
""" Differentiable function doing second level forward DTCWT """
@staticmethod
@custom_fwd
def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode):
mode = 'symmetric'
ctx.mode = mode
@@ -392,6 +396,7 @@ def forward(ctx, x, h0a, h1a, h0b, h1b, skip_hps, o_dim, ri_dim, mode):
return ll, highs

@staticmethod
@custom_bwd
def backward(ctx, dl, dh):
h0a, h1a, h0b, h1b = ctx.saved_tensors
mode = ctx.mode
@@ -416,6 +421,7 @@ def backward(ctx, dl, dh):
class INV_J1(Function):
""" Differentiable function doing 1 level inverse DTCWT """
@staticmethod
@custom_fwd
def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode):
mode = int_to_mode(mode)
ctx.mode = mode
@@ -431,6 +437,7 @@ def forward(ctx, lows, highs, g0, g1, o_dim, ri_dim, mode):
return y

@staticmethod
@custom_bwd
def backward(ctx, dy):
g0, g1 = ctx.saved_tensors
dl = None
@@ -452,6 +459,7 @@ def backward(ctx, dy):
class INV_J2PLUS(Function):
""" Differentiable function doing level 2 onwards inverse DTCWT """
@staticmethod
@custom_fwd
def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode):
mode = 'symmetric'
ctx.mode = mode
@@ -468,6 +476,7 @@ def forward(ctx, lows, highs, g0a, g1a, g0b, g1b, o_dim, ri_dim, mode):
return y

@staticmethod
@custom_bwd
def backward(ctx, dy):
g0a, g1a, g0b, g1b = ctx.saved_tensors
g0a, g0b = g0b, g0a
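With `FWD_J1`, `FWD_J2PLUS`, `INV_J1` and `INV_J2PLUS` decorated, the public `DTCWTForward`/`DTCWTInverse` modules that wrap them should now work inside an autocast region. A hedged usage sketch (the shapes and the dtype behaviour under autocast are assumptions, not something this diff asserts):

```python
import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

xfm = DTCWTForward(J=3).cuda()   # calls FWD_J1 / FWD_J2PLUS internally
ifm = DTCWTInverse().cuda()      # calls INV_J1 / INV_J2PLUS internally
x = torch.randn(8, 3, 64, 64, device='cuda', requires_grad=True)

with torch.cuda.amp.autocast():
    yl, yh = xfm(x)              # lowpass + per-scale oriented highpasses
    y = ifm((yl, yh))
    loss = (y - x).abs().mean()
loss.backward()                  # custom_bwd keeps backward dtypes consistent
```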
9 changes: 9 additions & 0 deletions pytorch_wavelets/dwt/lowlevel.py
@@ -2,6 +2,7 @@
import torch.nn.functional as F
import numpy as np
from torch.autograd import Function
from torch.cuda.amp import custom_fwd, custom_bwd
from pytorch_wavelets.utils import reflect
import pywt

@@ -333,6 +334,7 @@ class AFB2D(Function):
y: Tensor of shape (N, C*4, H, W)
"""
@staticmethod
@custom_fwd
def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
ctx.save_for_backward(h0_row, h1_row, h0_col, h1_col)
ctx.shape = x.shape[-2:]
@@ -347,6 +349,7 @@ def forward(ctx, x, h0_row, h1_row, h0_col, h1_col, mode):
return low, highs

@staticmethod
@custom_bwd
def backward(ctx, low, highs):
dx = None
if ctx.needs_input_grad[0]:
@@ -386,6 +389,7 @@ class AFB1D(Function):
x1: Tensor of shape (N, C, L') - highpass
"""
@staticmethod
@custom_fwd
def forward(ctx, x, h0, h1, mode):
mode = int_to_mode(mode)

@@ -405,6 +409,7 @@ def forward(ctx, x, h0, h1, mode):
return x0, x1

@staticmethod
@custom_bwd
def backward(ctx, dx0, dx1):
dx = None
if ctx.needs_input_grad[0]:
@@ -668,6 +673,7 @@ class SFB2D(Function):
y: Tensor of shape (N, C*4, H, W)
"""
@staticmethod
@custom_fwd
def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
mode = int_to_mode(mode)
ctx.mode = mode
@@ -680,6 +686,7 @@ def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
return y

@staticmethod
@custom_bwd
def backward(ctx, dy):
dlow, dhigh = None, None
if ctx.needs_input_grad[0]:
@@ -715,6 +722,7 @@ class SFB1D(Function):
y: Tensor of shape (N, C*2, L')
"""
@staticmethod
@custom_fwd
def forward(ctx, low, high, g0, g1, mode):
mode = int_to_mode(mode)
# Make into a 2d tensor with 1 row
@@ -729,6 +737,7 @@ def forward(ctx, low, high, g0, g1, mode):
return sfb1d(low, high, g0, g1, mode=mode, dim=3)[:, :, 0]

@staticmethod
@custom_bwd
def backward(ctx, dy):
dlow, dhigh = None, None
if ctx.needs_input_grad[0]:
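The DWT path gets the same treatment through `AFB1D`/`AFB2D` (analysis) and `SFB1D`/`SFB2D` (synthesis). A sketch of a full mixed-precision training step through the patched layers, using the usual `GradScaler` recipe; the one-layer "model" is invented purely for illustration:

```python
import torch
from pytorch_wavelets import DWTForward, DWTInverse

dwt = DWTForward(J=1, wave='db3', mode='symmetric').cuda()   # AFB2D underneath
idwt = DWTInverse(wave='db3', mode='symmetric').cuda()       # SFB2D underneath
conv = torch.nn.Conv2d(3, 3, 1).cuda()                       # toy model
opt = torch.optim.SGD(conv.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 3, 32, 32, device='cuda')
with torch.cuda.amp.autocast():
    yl, yh = dwt(conv(x))
    rec = idwt((yl, yh))
    loss = (rec - x).abs().mean()
scaler.scale(loss).backward()    # unscaled later by scaler.step
scaler.step(opt)
scaler.update()
```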
11 changes: 11 additions & 0 deletions pytorch_wavelets/scatternet/lowlevel.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1, inv_j1
from pytorch_wavelets.dtcwt.transform_funcs import fwd_j1_rot, inv_j1_rot
@@ -49,6 +50,7 @@ def int_to_mode(mode):
class SmoothMagFn(torch.autograd.Function):
""" Class to do complex magnitude """
@staticmethod
@custom_fwd
def forward(ctx, x, y, b):
r = torch.sqrt(x**2 + y**2 + b**2)
if x.requires_grad:
@@ -59,6 +61,7 @@ def forward(ctx, x, y, b):
return r - b

@staticmethod
@custom_bwd
def backward(ctx, dr):
dx = None
if ctx.needs_input_grad[0]:
@@ -73,6 +76,7 @@ class ScatLayerj1_f(torch.autograd.Function):
layer with the DTCWT biorthogonal filters. """

@staticmethod
@custom_fwd
def forward(ctx, x, h0o, h1o, mode, bias, combine_colour):
# bias = 1e-2
# bias = 0
@@ -111,6 +115,7 @@ def forward(ctx, x, h0o, h1o, mode, bias, combine_colour):
return Z

@staticmethod
@custom_bwd
def backward(ctx, dZ):
dX = None
mode = ctx.mode
@@ -143,6 +148,7 @@ class ScatLayerj1_rot_f(torch.autograd.Function):
filters, i.e. a slightly more expensive operation."""

@staticmethod
@custom_fwd
def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour):
mode = int_to_mode(mode)
ctx.mode = mode
@@ -179,6 +185,7 @@ def forward(ctx, x, h0o, h1o, h2o, mode, bias, combine_colour):
return Z

@staticmethod
@custom_bwd
def backward(ctx, dZ):
dX = None
mode = ctx.mode
@@ -208,6 +215,7 @@ class ScatLayerj2_f(torch.autograd.Function):
layer with the DTCWT biorthogonal filters. """

@staticmethod
@custom_fwd
def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour):
# bias = 1e-2
# bias = 0
@@ -309,6 +317,7 @@ def forward(ctx, x, h0o, h1o, h0a, h0b, h1a, h1b, mode, bias, combine_colour):
return Z

@staticmethod
@custom_bwd
def backward(ctx, dZ):
dX = None
mode = ctx.mode
@@ -403,6 +412,7 @@ class ScatLayerj2_rot_f(torch.autograd.Function):
layer with the DTCWT bandpass biorthogonal and qshift filters. """

@staticmethod
@custom_fwd
def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour):
# bias = 1e-2
# bias = 0
@@ -502,6 +512,7 @@ def forward(ctx, x, h0o, h1o, h2o, h0a, h0b, h1a, h1b, h2a, h2b, mode, bias, combine_colour):
return Z

@staticmethod
@custom_bwd
def backward(ctx, dZ):
dX = None
mode = ctx.mode
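Finally, the scatternet Functions (`SmoothMagFn` plus the `ScatLayerj1_f`/`ScatLayerj2_f` variants) sit underneath the package's `ScatLayer` modules, so a scattering front end should also run under autocast once this lands. A minimal sketch, assuming `ScatLayer` is importable from the package root as in the README:

```python
import torch
from pytorch_wavelets import ScatLayer

scat = ScatLayer().cuda()        # single-order DTCWT scattering layer
x = torch.randn(2, 3, 64, 64, device='cuda', requires_grad=True)

with torch.cuda.amp.autocast():
    z = scat(x)                  # complex magnitudes via SmoothMagFn
z.float().mean().backward()      # custom_bwd replays the autocast state
```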