diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py index ed9da7185..38bf73892 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py @@ -184,7 +184,7 @@ def __init__(self, nn.ConvTranspose2d( in_chans, out_chans, kernel_size=2, stride=2, bias=False), nn.GroupNorm(num_groups=1, num_channels=out_chans, eps=1e-6), - norm_layer, + norm_layer(out_chans), activation_fn, )