add the residual unet from unet² paper
lucidrains committed Jul 23, 2022
1 parent e481ae8 commit 7f811fa
Showing 4 changed files with 130 additions and 5 deletions.
20 changes: 19 additions & 1 deletion README.md
@@ -4,4 +4,22 @@ Implementation of a U-net complete with efficient attention as well as the latest …

## Citations

-tbd
```bibtex
@article{Ronneberger2015UNetCN,
    title   = {U-Net: Convolutional Networks for Biomedical Image Segmentation},
    author  = {Olaf Ronneberger and Philipp Fischer and Thomas Brox},
    journal = {ArXiv},
    year    = {2015},
    volume  = {abs/1505.04597}
}
```

```bibtex
@article{Qin2020U2NetGD,
    title   = {U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection},
    author  = {Xuebin Qin and Zichen Vincent Zhang and Chenyang Huang and Masood Dehghan and Osmar R Zaiane and Martin J{\"a}gersand},
    journal = {ArXiv},
    year    = {2020},
    volume  = {abs/2005.09007}
}
```
6 changes: 4 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-unet',
  packages = find_packages(exclude=[]),
-  version = '0.0.2',
  version = '0.0.3',
  license='MIT',
  description = 'X-Unet',
  long_description_content_type = 'text/markdown',
@@ -13,7 +13,9 @@
  keywords = [
    'artificial intelligence',
    'deep learning',
-    'unets'
    'biomedical segmentation',
    'medical deep learning',
    'unets',
  ],
  install_requires=[
    'einops>=0.4',
2 changes: 1 addition & 1 deletion x_unet/__init__.py
@@ -1 +1 @@
-from x_unet.x_unet import XUnet
from x_unet.x_unet import XUnet, NestedResidualUnet
107 changes: 106 additions & 1 deletion x_unet/x_unet.py
@@ -1,9 +1,10 @@
from functools import partial

import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
-from einops import rearrange
from einops import rearrange, repeat

# helper functions

@@ -13,11 +14,40 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

def is_power_two(n):
    return math.log2(n).is_integer()

# helper classes

def Upsample(dim):
    # 2x upsampling of height and width only, for (b, c, f, h, w) tensors; frames untouched
    return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

class PixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = nn.Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(conv)

    def init_conv_(self, conv):
        # ICNR-style initialization - the 4 output filters feeding each 2x2
        # subpixel block share one set of weights, so at init the pixel shuffle
        # reduces to nearest-neighbor upsampling (avoids checkerboard artifacts)
        o, i, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
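
A quick sanity check, not part of the commit, illustrating the intent of `init_conv_`: right after construction, the four subpixels of every 2x2 output block are identical, so the module starts out as a nearest-neighbor upsampler:

```python
import torch

up = PixelShuffleUpsample(dim = 8, dim_out = 16)  # as defined above
x = torch.randn(1, 8, 32, 32)

with torch.no_grad():
    out = up(x)  # (1, 16, 64, 64)

# all four subpixels of each 2x2 block agree at init
assert torch.allclose(out[..., ::2, ::2], out[..., 1::2, ::2])
assert torch.allclose(out[..., ::2, ::2], out[..., ::2, 1::2])
assert torch.allclose(out[..., ::2, ::2], out[..., 1::2, 1::2])
```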

def Downsample(dim):
    # 2x downsampling of height and width only; frame dimension untouched
    return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1))

@@ -149,3 +179,78 @@ def forward(self, x):
        x = rearrange(x, 'b c 1 h w -> b c h w')

        return x

# RSU - residual U-block from the U^2-Net paper

class NestedResidualUnet(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        M = 32,              # hidden channel width inside the block
        add_residual = False,
        groups = 4
    ):
        super().__init__()

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        for ind in range(depth):
            is_first = ind == 0
            dim_in = dim if is_first else M

            down = nn.Sequential(
                nn.Conv2d(dim_in, M, 4, stride = 2, padding = 1),
                nn.GroupNorm(groups, M),
                nn.SiLU()
            )

            up = nn.Sequential(
                PixelShuffleUpsample(2 * M, dim_in),
                nn.GroupNorm(groups, dim_in),
                nn.SiLU()
            )

            self.downs.append(down)
            self.ups.append(up)

        self.mid = nn.Sequential(
            nn.Conv2d(M, M, 3, padding = 1),
            nn.GroupNorm(groups, M),
            nn.SiLU()
        )

        self.add_residual = add_residual

    def forward(self, x):
        *_, h, w = x.shape
        assert is_power_two(h) and is_power_two(w), 'height and width must be powers of two'

        if self.add_residual:
            residual = x.clone()

        # hiddens

        hiddens = []

        # unet

        for down in self.downs:
            x = down(x)
            hiddens.append(x.clone().contiguous())

        x = self.mid(x)

        for up in reversed(self.ups):
            x = torch.cat((x, hiddens.pop()), dim = 1)
            x = up(x)

        # adding residual

        if self.add_residual:
            x = x + residual
            x = F.silu(x)

        return x
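
A minimal usage sketch, not part of the commit (the import works after this commit's change to `x_unet/__init__.py`); channels and spatial size are preserved end to end, so the output matches the input shape:

```python
import torch
from x_unet import NestedResidualUnet

unet = NestedResidualUnet(
    dim = 64,            # input / output channels
    depth = 4,           # number of downsample / upsample stages
    M = 32,              # hidden channels inside the block
    add_residual = True  # add the input back and apply a final SiLU
)

x = torch.randn(1, 64, 256, 256)  # height and width must be powers of two
out = unet(x)                     # (1, 64, 256, 256)
```

Because the block is shape-preserving, it can be dropped into a larger network as a residual stage, mirroring how the U^2-Net paper nests RSU blocks inside an outer U-structure.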
