From 09431e2d82a2360039475c1d8ae43f411ac20b45 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 3 May 2024 10:21:48 -0700
Subject: [PATCH] remove group norms https://arxiv.org/abs/2312.02696

---
 setup.py         |  2 +-
 x_unet/x_unet.py | 26 ++++++++++++++++----------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/setup.py b/setup.py
index 6b6e630..2cf529f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-unet',
   packages = find_packages(exclude=[]),
-  version = '0.3.1',
+  version = '0.4.0',
   license='MIT',
   description = 'X-Unet',
   long_description_content_type = 'text/markdown',
diff --git a/x_unet/x_unet.py b/x_unet/x_unet.py
index f33fe24..d6cafb0 100644
--- a/x_unet/x_unet.py
+++ b/x_unet/x_unet.py
@@ -57,6 +57,15 @@ def __init__(self, fn):
     def forward(self, x):
         return self.fn(x) + x
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -96,7 +105,7 @@ def __init__(
         conv = nn.Conv3d if not weight_standardize else WeightStandardizedConv3d
 
         self.proj = conv(dim, dim_out, **kernel_conv_kwargs(3, 3))
-        self.norm = nn.GroupNorm(groups, dim_out)
+        self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
 
     def forward(self, x):
@@ -109,19 +118,18 @@ def __init__(
         self,
         dim,
         dim_out,
-        groups = 8,
         frame_kernel_size = 1,
         nested_unet_depth = 0,
         nested_unet_dim = 32,
         weight_standardize = False
     ):
         super().__init__()
-        self.block1 = Block(dim, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+        self.block1 = Block(dim, dim_out, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         if nested_unet_depth > 0:
             self.block2 = NestedResidualUnet(dim_out, depth = nested_unet_depth, M = nested_unet_dim, frame_kernel_size = frame_kernel_size, weight_standardize = weight_standardize, add_residual = True)
         else:
-            self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+            self.block2 = Block(dim_out, dim_out, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
@@ -317,7 +325,6 @@ def __init__(
         nested_unet_dim = 32,
         channels = 3,
         use_convnext = False,
-        resnet_groups = 8,
         consolidate_upsample_fmaps = True,
         skip_scale = 2 ** -0.5,
         weight_standardize = False,
@@ -344,7 +351,7 @@ def __init__(
 
         # resnet or convnext
 
-        blocks = partial(ConvNextBlock, frame_kernel_size = frame_kernel_size) if use_convnext else partial(ResnetBlock, groups = resnet_groups, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
+        blocks = partial(ConvNextBlock, frame_kernel_size = frame_kernel_size) if use_convnext else partial(ResnetBlock, weight_standardize = weight_standardize, frame_kernel_size = frame_kernel_size)
 
         # whether to use nested unet, as in unet squared paper
 
@@ -557,7 +564,6 @@ def __init__(
         M = 32,
         frame_kernel_size = 1,
         add_residual = False,
-        groups = 4,
         skip_scale = 2 ** -0.5,
         weight_standardize = False
     ):
@@ -575,13 +581,13 @@ def __init__(
 
             down = nn.Sequential(
                 conv(dim_in, M, (1, 4, 4), stride = (1, 2, 2), padding = (0, 1, 1)),
-                nn.GroupNorm(groups, M),
+                RMSNorm(M),
                 nn.SiLU()
             )
 
             up = nn.Sequential(
                 PixelShuffleUpsample(2 * M, dim_in),
-                nn.GroupNorm(groups, dim_in),
+                RMSNorm(dim_in),
                 nn.SiLU()
             )
 
@@ -590,7 +596,7 @@ def __init__(
 
         self.mid = nn.Sequential(
             conv(M, M, **kernel_and_same_pad(frame_kernel_size, 3, 3)),
-            nn.GroupNorm(groups, M),
+            RMSNorm(M),
             nn.SiLU()
         )
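
Note (not part of the patch): below is a minimal standalone sketch of the RMSNorm that replaces nn.GroupNorm above, copied from the new class added at x_unet/x_unet.py line 57. The tensor sizes in the usage example are arbitrary assumptions for the 3D-convolution case used throughout x_unet.py.

import torch
import torch.nn.functional as F
from torch import nn

# RMSNorm as introduced by the patch: L2-normalize over the channel
# dimension, then rescale by sqrt(dim) and a learned per-channel gamma
# (per https://arxiv.org/abs/2312.02696).
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim = 1) * self.scale * self.gamma

# drop-in where nn.GroupNorm(groups, dim) was used before; no group count needed
norm = RMSNorm(32)
x = torch.randn(2, 32, 4, 16, 16)   # (batch, channels, frames, height, width) - sizes assumed
assert norm(x).shape == x.shape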