[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 4, 2024
1 parent aa20c9d commit 7ace6c0
Showing 9 changed files with 292 additions and 230 deletions.
2 changes: 1 addition & 1 deletion generative/smrd/README.md
@@ -2,4 +2,4 @@ SURE-Based Robust MRI Reconstruction with Diffusion Models (SMRD). MICCAI 2023 (

![SMRD](figures/SMRD.png)
It showcases how the conjugate gradient method can be used to enforce measurement consistency in diffusion-model-based MRI reconstruction; it also shows
how the SURE-based method can be used to perform early stopping, so that fewer iterations are run and fewer artifacts are introduced while generating the reconstructed image.
how the SURE-based method can be used to perform early stopping, so that fewer iterations are run and fewer artifacts are introduced while generating the reconstructed image.
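
To make those two ideas concrete, here is a minimal, self-contained sketch of the pattern described: a denoising loop that enforces measurement consistency with a few conjugate-gradient steps and stops early when a Monte-Carlo SURE estimate stops improving. This is not the notebook's implementation; the denoiser, measurement operator, weights, and thresholds below are illustrative assumptions.

```python
# Illustrative sketch only -- the denoiser, measurement operator, weights, and stopping
# thresholds are assumptions, not the notebook's implementation.
import torch


def cg_solve(apply_a, b, x0, n_iters=5):
    """A few conjugate-gradient steps for the SPD system apply_a(x) = b."""
    x = x0.clone()
    r = b - apply_a(x)
    p = r.clone()
    rs_old = torch.sum(r * r)
    for _ in range(n_iters):
        ap = apply_a(p)
        alpha = rs_old / (torch.sum(p * ap) + 1e-12)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = torch.sum(r * r)
        p = r + (rs_new / (rs_old + 1e-12)) * p
        rs_old = rs_new
    return x


def sure_estimate(x_noisy, denoiser, sigma, eps=1e-3):
    """Monte-Carlo SURE: data-fit term plus a randomized divergence estimate."""
    d = x_noisy.numel()
    x_hat = denoiser(x_noisy)
    b = torch.randn_like(x_noisy)
    div = torch.sum(b * (denoiser(x_noisy + eps * b) - x_hat)) / eps
    return (torch.sum((x_hat - x_noisy) ** 2) - d * sigma**2 + 2 * sigma**2 * div) / d


# Toy measurement model: a random undersampling mask standing in for k-space sampling.
torch.manual_seed(0)
mask = (torch.rand(1, 1, 32, 32) > 0.5).float()
x_true = torch.randn(1, 1, 32, 32)
sigma = 0.05
y = mask * x_true + sigma * torch.randn_like(x_true)


def denoiser(x):
    return 0.9 * x  # stand-in for the score/denoising network


lam = 2.0  # data-consistency weight (cf. lambda_init in the demo config)
x = torch.zeros_like(y)
best_sure, patience = float("inf"), 0

for t in range(200):
    x = denoiser(x)  # one "reverse diffusion" denoising step
    # Conjugate gradient on the regularized normal equations
    # (A^T A + lam I) x = A^T y + lam x_denoised, pulling the estimate toward the
    # measurements without leaving the denoised neighborhood.
    x = cg_solve(lambda v: mask * v + lam * v, mask * y + lam * x, x, n_iters=3)
    sure = sure_estimate(x, denoiser, sigma)
    if sure < best_sure - 1e-4:
        best_sure, patience = sure, 0
    else:
        patience += 1
    if patience >= 5:  # SURE-based early stopping
        break
```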
178 changes: 94 additions & 84 deletions generative/smrd/SMRD.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions generative/smrd/configs/demo/SMRD-brain_T2-noise005-R8.yaml
@@ -39,7 +39,7 @@ lambda_lr: 0.2
init_lambda_update: 1154
last_lambda_update: 1655

## Lambda
## Lambda
lambda_init: 2.0
lambda_end: 2.0
lambda_func: learnable
@@ -131,4 +131,3 @@ langevin_config:
beta1: 0.9
amsgrad: false
eps: 0.001
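
The lambda_* fields above (lambda_init, lambda_func: learnable, lambda_lr, init_lambda_update, last_lambda_update) suggest that the data-consistency weight is a learnable scalar updated only inside a fixed iteration window. A heavily hedged sketch of that pattern follows; the optimizer choice and the placeholder objective are assumptions, not read from the code.

```python
# Hedged sketch of an iteration-windowed, learnable data-consistency weight using the
# config values above (lambda_init=2.0, lambda_lr=0.2, window 1154..1655). The objective
# is a placeholder; the actual criterion used by SMRD is not shown in this diff.
import torch

lam = torch.tensor(2.0, requires_grad=True)  # lambda_init
lam_opt = torch.optim.Adam([lam], lr=0.2)    # lambda_lr (optimizer choice is an assumption)

for t in range(1656):                        # last_lambda_update = 1655
    # ... a reverse-diffusion / data-consistency step would run here ...
    if 1154 <= t <= 1655:                    # init_lambda_update .. last_lambda_update
        surrogate = (lam - 1.0) ** 2         # placeholder loss standing in for a SURE-style criterion
        lam_opt.zero_grad()
        surrogate.backward()
        lam_opt.step()
```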

2 changes: 1 addition & 1 deletion generative/smrd/models/__init__.py
@@ -7,4 +7,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
5 changes: 3 additions & 2 deletions generative/smrd/models/ema.py
@@ -11,6 +11,7 @@

import torch.nn as nn


class EMAHelper(object):
def __init__(self, mu=0.999):
self.mu = mu
@@ -28,7 +29,7 @@ def update(self, module):
module = module.module
for name, param in module.named_parameters():
if param.requires_grad:
self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
self.shadow[name].data = (1.0 - self.mu) * param.data + self.mu * self.shadow[name].data

def ema(self, module):
if isinstance(module, nn.DataParallel):
@@ -54,4 +55,4 @@ def state_dict(self):
return self.shadow

def load_state_dict(self, state_dict):
self.shadow = state_dict
self.shadow = state_dict
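
The update shown above maintains an exponential moving average of the weights, shadow = mu * shadow + (1 - mu) * param. A standalone sketch of that arithmetic in a toy training loop (the model, loss, and loop are illustrative; the repo drives the same rule through EMAHelper):

```python
# Standalone illustration of the EMA rule used by EMAHelper.update; the toy model,
# loss, and training loop below are assumptions for the example only.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
mu = 0.999  # same default as EMAHelper(mu=0.999)

# Shadow copy of every trainable parameter, initialized to the current weights.
shadow = {name: p.detach().clone() for name, p in model.named_parameters() if p.requires_grad}

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Same rule as the line changed above: shadow <- mu * shadow + (1 - mu) * param
    with torch.no_grad():
        for name, p in model.named_parameters():
            if p.requires_grad:
                shadow[name].mul_(mu).add_(p, alpha=1.0 - mu)

# At evaluation time the averaged weights are copied back into (a copy of) the module,
# which is what EMAHelper.ema(module) does.
```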
120 changes: 72 additions & 48 deletions generative/smrd/models/layers.py
@@ -17,44 +17,45 @@


def get_act(config):
if config.model.nonlinearity.lower() == 'elu':
if config.model.nonlinearity.lower() == "elu":
return nn.ELU()
elif config.model.nonlinearity.lower() == 'relu':
elif config.model.nonlinearity.lower() == "relu":
return nn.ReLU()
elif config.model.nonlinearity.lower() == 'lrelu':
elif config.model.nonlinearity.lower() == "lrelu":
return nn.LeakyReLU(negative_slope=0.2)
elif config.model.nonlinearity.lower() == 'swish':
elif config.model.nonlinearity.lower() == "swish":

def swish(x):
return x * torch.sigmoid(x)

return swish
else:
raise NotImplementedError('activation function does not exist!')
raise NotImplementedError("activation function does not exist!")


def spectral_norm(layer, n_iters=1):
return torch.nn.utils.spectral_norm(layer, n_power_iterations=n_iters)


def conv1x1(in_planes, out_planes, stride=1, bias=True, spec_norm=False):
"1x1 convolution"
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
padding=0, bias=bias)
conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias)
if spec_norm:
conv = spectral_norm(conv)
return conv


def conv3x3(in_planes, out_planes, stride=1, bias=True, spec_norm=False):
"3x3 convolution with padding"
conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=bias)
conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias)
if spec_norm:
conv = spectral_norm(conv)

return conv


def stride_conv3x3(in_planes, out_planes, kernel_size, bias=True, spec_norm=False):
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=kernel_size // 2, bias=bias)
conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=kernel_size // 2, bias=bias)
if spec_norm:
conv = spectral_norm(conv)
return conv
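
As an aside, the swish branch in get_act above is the same function PyTorch ships as nn.SiLU (x * sigmoid(x)); a quick standalone check:

```python
# swish(x) = x * sigmoid(x) is identical to PyTorch's built-in SiLU activation.
import torch
import torch.nn as nn

x = torch.randn(5)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))
```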
@@ -67,6 +68,7 @@ def dilated_conv3x3(in_planes, out_planes, dilation, bias=True, spec_norm=False)

return conv


class CRPBlock(nn.Module):
def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True, spec_norm=False):
super().__init__()
@@ -123,8 +125,11 @@ def __init__(self, features, n_blocks, n_stages, act=nn.ReLU(), spec_norm=False)

for i in range(n_blocks):
for j in range(n_stages):
setattr(self, '{}_{}_conv'.format(i + 1, j + 1), conv3x3(features, features, stride=1, bias=False,
spec_norm=spec_norm))
setattr(
self,
"{}_{}_conv".format(i + 1, j + 1),
conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm),
)

self.stride = 1
self.n_blocks = n_blocks
@@ -136,7 +141,7 @@ def forward(self, x):
residual = x
for j in range(self.n_stages):
x = self.act(x)
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x)

x += residual
return x
@@ -148,9 +153,12 @@ def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn

for i in range(n_blocks):
for j in range(n_stages):
setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
setattr(self, '{}_{}_conv'.format(i + 1, j + 1),
conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm))
setattr(self, "{}_{}_norm".format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
setattr(
self,
"{}_{}_conv".format(i + 1, j + 1),
conv3x3(features, features, stride=1, bias=False, spec_norm=spec_norm),
)

self.stride = 1
self.n_blocks = n_blocks
@@ -162,9 +170,9 @@ def forward(self, x, y):
for i in range(self.n_blocks):
residual = x
for j in range(self.n_stages):
x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
x = getattr(self, "{}_{}_norm".format(i + 1, j + 1))(x, y)
x = self.act(x)
x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
x = getattr(self, "{}_{}_conv".format(i + 1, j + 1))(x)

x += residual
return x
@@ -187,7 +195,7 @@ def forward(self, xs, shape):
sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
for i in range(len(self.convs)):
h = self.convs[i](xs[i])
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True)
sums += h
return sums

@@ -214,7 +222,7 @@ def forward(self, xs, y, shape):
for i in range(len(self.convs)):
h = self.norms[i](xs[i], y)
h = self.convs[i](h)
h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
h = F.interpolate(h, size=shape, mode="bilinear", align_corners=True)
sums += h
return sums

@@ -228,9 +236,7 @@ def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, m

self.adapt_convs = nn.ModuleList()
for i in range(n_blocks):
self.adapt_convs.append(
RCUBlock(in_planes[i], 2, 2, act, spec_norm=spec_norm)
)
self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act, spec_norm=spec_norm))

self.output_convs = RCUBlock(features, 3 if end else 1, 2, act, spec_norm=spec_norm)

@@ -257,21 +263,22 @@ def forward(self, xs, output_shape):
return h



class CondRefineBlock(nn.Module):
def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False, spec_norm=False):
def __init__(
self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False, spec_norm=False
):
super().__init__()

assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
self.n_blocks = n_blocks = len(in_planes)

self.adapt_convs = nn.ModuleList()
for i in range(n_blocks):
self.adapt_convs.append(
CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act, spec_norm=spec_norm)
)
self.adapt_convs.append(CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act, spec_norm=spec_norm))

self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act, spec_norm=spec_norm)
self.output_convs = CondRCUBlock(
features, 3 if end else 1, 2, num_classes, normalizer, act, spec_norm=spec_norm
)

if not start:
self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer, spec_norm=spec_norm)
@@ -309,17 +316,15 @@ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_pad
if spec_norm:
conv = spectral_norm(conv)

self.conv = nn.Sequential(
nn.ZeroPad2d((1, 0, 1, 0)),
conv
)
self.avg = nn.AvgPool2d( kernel_size=2 )
self.conv = nn.Sequential(nn.ZeroPad2d((1, 0, 1, 0)), conv)
self.avg = nn.AvgPool2d(kernel_size=2)

def forward(self, inputs):
output = self.conv(inputs)
output = self.avg( output )
output = self.avg(output)
return output


class MeanPoolConv(nn.Module):
def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=False):
super().__init__()
@@ -329,8 +334,10 @@ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, spec_norm=

def forward(self, inputs):
output = inputs
output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
output = (
sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]])
/ 4.0
)
return self.conv(output)
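
The four strided slices averaged in MeanPoolConv.forward are exactly 2x2 mean pooling; a quick standalone check against F.avg_pool2d (assuming even spatial sizes):

```python
# The slice-and-average in MeanPoolConv.forward equals 2x2 average pooling.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
manual = sum([x[:, :, ::2, ::2], x[:, :, 1::2, ::2], x[:, :, ::2, 1::2], x[:, :, 1::2, 1::2]]) / 4.0
assert torch.allclose(manual, F.avg_pool2d(x, kernel_size=2))
```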


@@ -350,15 +357,25 @@ def forward(self, inputs):


class ConditionalResidualBlock(nn.Module):
def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU(),
normalization=ConditionalBatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False):
def __init__(
self,
input_dim,
output_dim,
num_classes,
resample=None,
act=nn.ELU(),
normalization=ConditionalBatchNorm2d,
adjust_padding=False,
dilation=None,
spec_norm=False,
):
super().__init__()
self.non_linearity = act
self.input_dim = input_dim
self.output_dim = output_dim
self.resample = resample
self.normalization = normalization
if resample == 'down':
if resample == "down":
if dilation is not None:
self.conv1 = dilated_conv3x3(input_dim, input_dim, dilation=dilation, spec_norm=spec_norm)
self.normalize2 = normalization(input_dim, num_classes)
@@ -382,14 +399,13 @@ def __init__(self, input_dim, output_dim, num_classes, resample=None, act=nn.ELU
self.normalize2 = normalization(output_dim, num_classes)
self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm)
else:
raise Exception('invalid resample value')
raise Exception("invalid resample value")

if output_dim != input_dim or resample is not None:
self.shortcut = conv_shortcut(input_dim, output_dim)

self.normalize1 = normalization(input_dim, num_classes)


def forward(self, x, y):
output = self.normalize1(x, y)
output = self.non_linearity(output)
@@ -407,15 +423,24 @@ def forward(self, x, y):


class ResidualBlock(nn.Module):
def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
normalization=nn.BatchNorm2d, adjust_padding=False, dilation=None, spec_norm=False):
def __init__(
self,
input_dim,
output_dim,
resample=None,
act=nn.ELU(),
normalization=nn.BatchNorm2d,
adjust_padding=False,
dilation=None,
spec_norm=False,
):
super().__init__()
self.non_linearity = act
self.input_dim = input_dim
self.output_dim = output_dim
self.resample = resample
self.normalization = normalization
if resample == 'down':
if resample == "down":
if dilation is not None:
self.conv1 = dilated_conv3x3(input_dim, input_dim, dilation=dilation, spec_norm=spec_norm)
self.normalize2 = normalization(input_dim)
@@ -440,14 +465,13 @@ def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
self.normalize2 = normalization(output_dim)
self.conv2 = conv3x3(output_dim, output_dim, spec_norm=spec_norm)
else:
raise Exception('invalid resample value')
raise Exception("invalid resample value")

if output_dim != input_dim or resample is not None:
self.shortcut = conv_shortcut(input_dim, output_dim)

self.normalize1 = normalization(input_dim)


def forward(self, x):
output = self.normalize1(x)
output = self.non_linearity(output)
