Commit 09431e2

lucidrains committed May 3, 2024
1 parent a8d1582 commit 09431e2
Showing 2 changed files with 17 additions and 11 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-unet',
   packages = find_packages(exclude=[]),
-  version = '0.3.1',
+  version = '0.4.0',
   license='MIT',
   description = 'X-Unet',
   long_description_content_type = 'text/markdown',
x_unet/x_unet.py (26 changes: 16 additions & 10 deletions)
@@ -57,6 +57,15 @@ def __init__(self, fn):
     def forward(self, x):
         return self.fn(x) + x
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
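The RMSNorm added above removes the need for a group count: it unit-normalizes each position across the channel dimension, then rescales by sqrt(dim) and a learned per-channel gain shaped to broadcast over (frames, height, width). A minimal standalone sketch of that behavior next to the GroupNorm call it replaces; the tensor shape and group count below are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    # mirrors the module introduced in this commit
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        # L2-normalize over channels, then restore magnitude with sqrt(dim) and a learned gain
        return F.normalize(x, dim = 1) * self.scale * self.gamma

x = torch.randn(2, 64, 4, 32, 32)        # (batch, channels, frames, height, width) -- illustrative

old_norm = nn.GroupNorm(8, 64)           # previously required choosing a group count
new_norm = RMSNorm(64)                   # only needs the channel dimension

print(old_norm(x).shape, new_norm(x).shape)   # both torch.Size([2, 64, 4, 32, 32])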
@@ -96,7 +105,7 @@ def __init__(
         conv = nn.Conv3d if not weight_standardize else WeightStandardizedConv3d
 
         self.proj = conv(dim, dim_out, **kernel_conv_kwargs(3, 3))
-        self.norm = nn.GroupNorm(groups, dim_out)
+        self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
 
     def forward(self, x):
@@ -109,19 +118,18 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         frame_kernel_size = 1,
         nested_unet_depth = 0,
         nested_unet_dim = 32,
         weight_standardize = False
     ):
         super().__init__()
-        self.block1 = Block(dim, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+        self.block1 = Block(dim, dim_out, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         if nested_unet_depth > 0:
             self.block2 = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, frame_kernel_size = frame_kernel_size, weight_standardize = weight_standardize, add_residual = True)
         else:
-            self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+            self.block2 = Block(dim_out, dim_out, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
@@ -317,7 +325,6 @@ def __init__(
         nested_unet_dim = 32,
         channels = 3,
         use_convnext = False,
-        resnet_groups = 8,
         consolidate_upsample_fmaps = True,
         skip_scale = 2 ** -0.5,
         weight_standardize = False,
@@ -344,7 +351,7 @@ def __init__(
 
         # resnet or convnext
 
-        blocks = partial(ConvNextBlock, frame_kernel_size = frame_kernel_size) if use_convnext else partial(ResnetBlock, groups = resnet_groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+        blocks = partial(ConvNextBlock, frame_kernel_size = frame_kernel_size) if use_convnext else partial(ResnetBlock, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         # whether to use nested unet, as in unet squared paper
 
@@ -557,7 +564,6 @@ def __init__(
         M = 32,
         frame_kernel_size = 1,
         add_residual = False,
-        groups = 4,
         skip_scale = 2 ** -0.5,
         weight_standardize = False
     ):
@@ -575,13 +581,13 @@ def __init__(
 
             down = nn.Sequential(
                 conv(dim_in, M, (1, 4, 4), stride = (1, 2, 2), padding = (0, 1, 1)),
-                nn.GroupNorm(groups, M),
+                RMSNorm(M),
                 nn.SiLU()
             )
 
             up = nn.Sequential(
                 PixelShuffleUpsample(2 * M, dim_in),
-                nn.GroupNorm(groups, dim_in),
+                RMSNorm(dim_in),
                 nn.SiLU()
             )
 
@@ -590,7 +596,7 @@ def __init__(
 
         self.mid = nn.Sequential(
             conv(M, M, **kernel_and_same_pad(frame_kernel_size, 3, 3)),
-            nn.GroupNorm(groups, M),
+            RMSNorm(M),
             nn.SiLU()
         )
 
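With the group counts gone from ResnetBlock, XUnet and NestedResidualUnet, callers upgrading to 0.4.0 simply drop the groups / resnet_groups keyword arguments; nothing else in the constructors changes in this commit. A hedged usage sketch, assuming the constructor and forward signature shown in the repository README (dim, dim_mults, image input of shape (batch, channels, height, width)) still apply.

import torch
from x_unet import XUnet

# 0.3.1 exposed the GroupNorm group count:
# unet = XUnet(dim = 64, dim_mults = (1, 2, 4, 8), resnet_groups = 8)

# 0.4.0 normalizes with RMSNorm, so the kwarg is no longer accepted:
unet = XUnet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

img = torch.randn(1, 3, 256, 256)
out = unet(img)   # expected to match the input shape, (1, 3, 256, 256)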

0 comments on commit 09431e2
