diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
index 6c88bfa90..621f42d33 100644
--- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
+++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py
@@ -65,33 +65,34 @@ def __call__(self, x, train=True):
     if dropout_rate is None:
       dropout_rate = 0.0
 
-    PreconfiguredConvBlock = functools.partial(
+    # pylint: disable=C0103
+    _ConvBlock = functools.partial(
         ConvBlock,
         dropout_rate=dropout_rate,
         use_tanh=self.use_tanh,
         use_layer_norm=self.use_layer_norm)
-    PreconfiguredTransposeConvBlock = functools.partial(
+    _TransposeConvBlock = functools.partial(
         TransposeConvBlock,
         use_tanh=self.use_tanh,
         use_layer_norm=self.use_layer_norm)
 
-    down_sample_layers = [PreconfiguredConvBlock(self.num_channels)]
+    down_sample_layers = [_ConvBlock(self.num_channels)]
     ch = self.num_channels
     for _ in range(self.num_pool_layers - 1):
-      down_sample_layers.append(PreconfiguredConvBlock(ch * 2))
+      down_sample_layers.append(_ConvBlock(ch * 2))
       ch *= 2
-    conv = PreconfiguredConvBlock(ch * 2)
+    conv = _ConvBlock(ch * 2)
 
     up_conv = []
     up_transpose_conv = []
     for _ in range(self.num_pool_layers - 1):
-      up_transpose_conv.append(PreconfiguredTransposeConvBlock(ch))
-      up_conv.append(PreconfiguredConvBlock(ch))
+      up_transpose_conv.append(_TransposeConvBlock(ch))
+      up_conv.append(_ConvBlock(ch))
       ch //= 2
-    up_transpose_conv.append(PreconfiguredTransposeConvBlock(ch))
-    up_conv.append(PreconfiguredConvBlock(ch))
+    up_transpose_conv.append(_TransposeConvBlock(ch))
+    up_conv.append(_ConvBlock(ch))
 
     stack = []
     output = jnp.expand_dims(x, axis=-1)
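
For readers unfamiliar with the pattern this diff renames: the `_ConvBlock` / `_TransposeConvBlock` aliases are ordinary `functools.partial` factories over Flax dataclass modules, binding the shared hyperparameters once so each call site only supplies the channel count. The sketch below is illustrative only; the `ConvBlock` here is a simplified, hypothetical stand-in (the real one lives in models.py), and the hyperparameter values are made up.

```python
import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


class ConvBlock(nn.Module):
  """Hypothetical, simplified stand-in for the ConvBlock in models.py."""
  out_channels: int
  dropout_rate: float = 0.0
  use_tanh: bool = False
  use_layer_norm: bool = False

  @nn.compact
  def __call__(self, x, train=True):
    x = nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False)(x)
    if self.use_layer_norm:
      x = nn.LayerNorm()(x)
    x = jnp.tanh(x) if self.use_tanh else nn.leaky_relu(x, negative_slope=0.2)
    return nn.Dropout(self.dropout_rate, deterministic=not train)(x)


# pylint: disable=C0103
# Bind the shared hyperparameters once; later call sites only pass the
# channel count, mirroring the pattern in the diff above.
_ConvBlock = functools.partial(
    ConvBlock, dropout_rate=0.1, use_tanh=False, use_layer_norm=True)

block = _ConvBlock(32)  # == ConvBlock(32, dropout_rate=0.1, ...)
variables = block.init(jax.random.PRNGKey(0),
                       jnp.ones((1, 64, 64, 1)),
                       train=False)
```

Because `nn.Module` subclasses are dataclasses, `functools.partial` composes cleanly with them; the leading underscore plus the `# pylint: disable=C0103` comment suppresses pylint's invalid-name warning for a class-cased local variable.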