From 1b9cc03801a9e005e0aeccd3b0f493c70b7fcf83 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:52:23 +0100 Subject: [PATCH 1/7] re-facored for 1D, other dimensions to follow --- modulus/models/fno/__init__.py | 2 +- modulus/models/fno/fno.py | 152 ++++++++++++++++++--------------- 2 files changed, 85 insertions(+), 69 deletions(-) diff --git a/modulus/models/fno/__init__.py b/modulus/models/fno/__init__.py index 0c272b5eb8..adccb972b1 100644 --- a/modulus/models/fno/__init__.py +++ b/modulus/models/fno/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fno import FNO +from .fno import FNO, FNO1DEncoder diff --git a/modulus/models/fno/fno.py b/modulus/models/fno/fno.py index 393c91cf4c..093f04773e 100644 --- a/modulus/models/fno/fno.py +++ b/modulus/models/fno/fno.py @@ -57,6 +57,7 @@ class FNO1DEncoder(nn.Module): Use coordinate grid as additional feature map, by default True """ + def __init__( self, in_channels: int = 1, @@ -70,22 +71,47 @@ def __init__( ) -> None: super().__init__() + self.set_vars(in_channels, + num_fno_layers, + fno_layer_size, + padding, + padding_type, + activation_fn, + coord_features) + self.build_lift_network() + self.build_fno(num_fno_modes) + + + def set_vars(self, + in_channels, + num_fno_layers, + fno_layer_size, + padding, + padding_type, + activation_fn, + coord_features): self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size - self.coord_features = coord_features - # Spectral modes to have weights - if isinstance(num_fno_modes, int): - num_fno_modes = [num_fno_modes] + self.activation_fn = activation_fn + # Add relative coordinate feature + self.coord_features = coord_features if self.coord_features: self.in_channels = self.in_channels + 1 - self.activation_fn = activation_fn + + # Padding values for spectral conv + if isinstance(padding, int): + padding = [padding] + self.pad = padding[:1] + self.ipad = [-pad if pad > 0 else None for pad in self.pad] + self.padding_type = padding_type self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() - # Initial lift network + + def build_lift_network(self): self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv1dFCLayer(self.in_channels, int(self.fno_width / 2)) @@ -95,6 +121,11 @@ def __init__( layers.Conv1dFCLayer(int(self.fno_width / 2), self.fno_width) ) + + def build_fno(self, num_fno_modes): + if isinstance(num_fno_modes, int): + num_fno_modes = [num_fno_modes] + # Build Neural Fourier Operators for _ in range(self.num_fno_layers): self.spconv_layers.append( @@ -102,12 +133,6 @@ def __init__( ) self.conv_layers.append(nn.Conv1d(self.fno_width, self.fno_width, 1)) - # Padding values for spectral conv - if isinstance(padding, int): - padding = [padding] - self.pad = padding[:1] - self.ipad = [-pad if pad > 0 else None for pad in self.pad] - self.padding_type = padding_type def forward(self, x: Tensor) -> Tensor: if self.coord_features: @@ -148,6 +173,16 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: grid_x = grid_x.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1) return grid_x + def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + y_shape = list(value.size()) + output = torch.permute(value, (0, 2, 1)) + return output.reshape(-1, output.size(-1)), y_shape + + + def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + output = value.reshape(shape[0], shape[2], value.size(-1)) + return torch.permute(output, (0, 2, 1)) + # =================================================================== # =================================================================== @@ -282,6 +317,16 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: grid_y = grid_y.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1) return torch.cat((grid_x, grid_y), dim=1) + def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + y_shape = list(value.size()) + output = torch.permute(value, (0, 2, 3, 1)) + return output.reshape(-1, output.size(-1)), y_shape + + + def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + output = value.reshape(shape[0], shape[2], shape[3], value.size(-1)) + return torch.permute(output, (0, 3, 1, 2)) + # =================================================================== # =================================================================== @@ -421,6 +466,17 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: return torch.cat((grid_x, grid_y, grid_z), dim=1) + def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + y_shape = list(value.size()) + output = torch.permute(value, (0, 2, 3, 4, 1)) + return output.reshape(-1, output.size(-1)), y_shape + + + def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1)) + return torch.permute(output, (0, 4, 1, 2, 3)) + + # =================================================================== # =================================================================== # 4D FNO @@ -571,52 +627,23 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: grid_t = grid_t.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1, 1) return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1) - -# Functions for converting between point based and grid (image) representations -def _grid_to_points1d(value: Tensor) -> Tuple[Tensor, List[int]]: - y_shape = list(value.size()) - output = torch.permute(value, (0, 2, 1)) - return output.reshape(-1, output.size(-1)), y_shape - - -def _points_to_grid1d(value: Tensor, shape: List[int]) -> Tensor: - output = value.reshape(shape[0], shape[2], value.size(-1)) - return torch.permute(output, (0, 2, 1)) - - -def _grid_to_points2d(value: Tensor) -> Tuple[Tensor, List[int]]: - y_shape = list(value.size()) - output = torch.permute(value, (0, 2, 3, 1)) - return output.reshape(-1, output.size(-1)), y_shape + def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + y_shape = list(value.size()) + output = torch.permute(value, (0, 2, 3, 4, 5, 1)) + return output.reshape(-1, output.size(-1)), y_shape -def _points_to_grid2d(value: Tensor, shape: List[int]) -> Tensor: - output = value.reshape(shape[0], shape[2], shape[3], value.size(-1)) - return torch.permute(output, (0, 3, 1, 2)) - - -def _grid_to_points3d(value: Tensor) -> Tuple[Tensor, List[int]]: - y_shape = list(value.size()) - output = torch.permute(value, (0, 2, 3, 4, 1)) - return output.reshape(-1, output.size(-1)), y_shape - + def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + output = value.reshape( + shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1) + ) + return torch.permute(output, (0, 5, 1, 2, 3, 4)) -def _points_to_grid3d(value: Tensor, shape: List[int]) -> Tensor: - output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1)) - return torch.permute(output, (0, 4, 1, 2, 3)) +# Functions for converting between point based and grid (image) representations -def _grid_to_points4d(value: Tensor) -> Tuple[Tensor, List[int]]: - y_shape = list(value.size()) - output = torch.permute(value, (0, 2, 3, 4, 5, 1)) - return output.reshape(-1, output.size(-1)), y_shape -def _points_to_grid4d(value: Tensor, shape: List[int]) -> Tensor: - output = value.reshape( - shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1) - ) - return torch.permute(output, (0, 5, 1, 2, 3, 4)) # =================================================================== @@ -739,21 +766,7 @@ def __init__( ) if dimension == 1: - FNOModel = FNO1DEncoder - self.grid_to_points = _grid_to_points1d # For JIT - self.points_to_grid = _points_to_grid1d # For JIT - elif dimension == 2: - FNOModel = FNO2DEncoder - self.grid_to_points = _grid_to_points2d # For JIT - self.points_to_grid = _points_to_grid2d # For JIT - elif dimension == 3: - FNOModel = FNO3DEncoder - self.grid_to_points = _grid_to_points3d # For JIT - self.points_to_grid = _points_to_grid3d # For JIT - elif dimension == 4: - FNOModel = FNO4DEncoder - self.grid_to_points = _grid_to_points4d # For JIT - self.points_to_grid = _points_to_grid4d # For JIT + FNOModel = self.get_FNO1DEncoder() else: raise NotImplementedError( "Invalid dimensionality. Only 1D, 2D, 3D and 4D FNO implemented" @@ -770,18 +783,21 @@ def __init__( coord_features=self.coord_features, ) + def get_FNO1DEncoder(self): # put all dims in here + return FNO1DEncoder + def forward(self, x: Tensor) -> Tensor: # Fourier encoder y_latent = self.spec_encoder(x) # Reshape to pointwise inputs if not a conv FC model y_shape = y_latent.shape - y_latent, y_shape = self.grid_to_points(y_latent) + y_latent, y_shape = self.spec_encoder.grid_to_points(y_latent) # Decoder y = self.decoder_net(y_latent) # Convert back into grid - y = self.points_to_grid(y, y_shape) + y = self.spec_encoder.points_to_grid(y, y_shape) return y From f8ad4f647c2e8f014ffc7211602ba0e70935c2b1 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Thu, 30 Nov 2023 13:58:15 +0100 Subject: [PATCH 2/7] made 2D-4D FNOs modular as well --- modulus/models/fno/__init__.py | 2 +- modulus/models/fno/fno.py | 171 ++++++++++++++++----------------- 2 files changed, 84 insertions(+), 89 deletions(-) diff --git a/modulus/models/fno/__init__.py b/modulus/models/fno/__init__.py index adccb972b1..2e65945366 100644 --- a/modulus/models/fno/__init__.py +++ b/modulus/models/fno/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fno import FNO, FNO1DEncoder +from .fno import FNO, FNO1DEncoder, FNO2DEncoder, FNO3DEncoder, FNO4DEncoder, MetaData diff --git a/modulus/models/fno/fno.py b/modulus/models/fno/fno.py index 093f04773e..dda8a20eeb 100644 --- a/modulus/models/fno/fno.py +++ b/modulus/models/fno/fno.py @@ -27,6 +27,7 @@ from ..mlp import FullyConnected from ..module import Module + # =================================================================== # =================================================================== # 1D FNO @@ -57,7 +58,6 @@ class FNO1DEncoder(nn.Module): Use coordinate grid as additional feature map, by default True """ - def __init__( self, in_channels: int = 1, @@ -71,25 +71,6 @@ def __init__( ) -> None: super().__init__() - self.set_vars(in_channels, - num_fno_layers, - fno_layer_size, - padding, - padding_type, - activation_fn, - coord_features) - self.build_lift_network() - self.build_fno(num_fno_modes) - - - def set_vars(self, - in_channels, - num_fno_layers, - fno_layer_size, - padding, - padding_type, - activation_fn, - coord_features): self.in_channels = in_channels self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size @@ -107,11 +88,11 @@ def set_vars(self, self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type - self.spconv_layers = nn.ModuleList() - self.conv_layers = nn.ModuleList() - + # build lift + self.build_lift_network() + self.build_fno(num_fno_modes) - def build_lift_network(self): + def build_lift_network(self) -> None: self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv1dFCLayer(self.in_channels, int(self.fno_width / 2)) @@ -121,19 +102,19 @@ def build_lift_network(self): layers.Conv1dFCLayer(int(self.fno_width / 2), self.fno_width) ) - - def build_fno(self, num_fno_modes): + def build_fno(self, num_fno_modes: int) -> None: if isinstance(num_fno_modes, int): num_fno_modes = [num_fno_modes] # Build Neural Fourier Operators + self.spconv_layers = nn.ModuleList() + self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv1d(self.fno_width, self.fno_width, num_fno_modes[0]) ) self.conv_layers.append(nn.Conv1d(self.fno_width, self.fno_width, 1)) - def forward(self, x: Tensor) -> Tensor: if self.coord_features: coord_feat = self.meshgrid(list(x.shape), x.device) @@ -178,7 +159,6 @@ def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: output = torch.permute(value, (0, 2, 1)) return output.reshape(-1, output.size(-1)), y_shape - def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: output = value.reshape(shape[0], shape[2], value.size(-1)) return torch.permute(output, (0, 2, 1)) @@ -230,17 +210,28 @@ def __init__( self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features - # Spectral modes to have weights - if isinstance(num_fno_modes, int): - num_fno_modes = [num_fno_modes, num_fno_modes] + self.activation_fn = activation_fn + # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 2 - self.activation_fn = activation_fn - self.spconv_layers = nn.ModuleList() - self.conv_layers = nn.ModuleList() + # Padding values for spectral conv + if isinstance(padding, int): + padding = [padding, padding] + padding = padding + [0, 0] # Pad with zeros for smaller lists + self.pad = padding[:2] + self.ipad = [-pad if pad > 0 else None for pad in self.pad] + self.padding_type = padding_type + + if isinstance(num_fno_modes, int): + num_fno_modes = [num_fno_modes, num_fno_modes] + + # build lift + self.build_lift_network() + self.build_fno(num_fno_modes) + def build_lift_network(self) -> None: # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -251,7 +242,10 @@ def __init__( layers.Conv2dFCLayer(int(self.fno_width / 2), self.fno_width) ) + def build_fno(self, num_fno_modes: int) -> None: # Build Neural Fourier Operators + self.spconv_layers = nn.ModuleList() + self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv2d( @@ -260,14 +254,6 @@ def __init__( ) self.conv_layers.append(nn.Conv2d(self.fno_width, self.fno_width, 1)) - # Padding values for spectral conv - if isinstance(padding, int): - padding = [padding, padding] - padding = padding + [0, 0] # Pad with zeros for smaller lists - self.pad = padding[:2] - self.ipad = [-pad if pad > 0 else None for pad in self.pad] - self.padding_type = padding_type - def forward(self, x: Tensor) -> Tensor: if x.dim() != 4: raise ValueError( @@ -322,7 +308,6 @@ def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: output = torch.permute(value, (0, 2, 3, 1)) return output.reshape(-1, output.size(-1)), y_shape - def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: output = value.reshape(shape[0], shape[2], shape[3], value.size(-1)) return torch.permute(output, (0, 3, 1, 2)) @@ -375,17 +360,28 @@ def __init__( self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features - # Spectral modes to have weights - if isinstance(num_fno_modes, int): - num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes] + self.activation_fn = activation_fn + # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 3 - self.activation_fn = activation_fn - self.spconv_layers = nn.ModuleList() - self.conv_layers = nn.ModuleList() + # Padding values for spectral conv + if isinstance(padding, int): + padding = [padding, padding, padding] + padding = padding + [0, 0, 0] # Pad with zeros for smaller lists + self.pad = padding[:3] + self.ipad = [-pad if pad > 0 else None for pad in self.pad] + self.padding_type = padding_type + if isinstance(num_fno_modes, int): + num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes] + + # build lift + self.build_lift_network() + self.build_fno(num_fno_modes) + + def build_lift_network(self) -> None: # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -396,7 +392,10 @@ def __init__( layers.Conv3dFCLayer(int(self.fno_width / 2), self.fno_width) ) + def build_fno(self, num_fno_modes: int) -> None: # Build Neural Fourier Operators + self.spconv_layers = nn.ModuleList() + self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv3d( @@ -409,14 +408,6 @@ def __init__( ) self.conv_layers.append(nn.Conv3d(self.fno_width, self.fno_width, 1)) - # Padding values for spectral conv - if isinstance(padding, int): - padding = [padding, padding, padding] - padding = padding + [0, 0, 0] # Pad with zeros for smaller lists - self.pad = padding[:3] - self.ipad = [-pad if pad > 0 else None for pad in self.pad] - self.padding_type = padding_type - def forward(self, x: Tensor) -> Tensor: if self.coord_features: coord_feat = self.meshgrid(list(x.shape), x.device) @@ -465,13 +456,11 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: grid_z = grid_z.unsqueeze(0).unsqueeze(0).repeat(bsize, 1, 1, 1, 1) return torch.cat((grid_x, grid_y, grid_z), dim=1) - def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 4, 1)) return output.reshape(-1, output.size(-1)), y_shape - def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1)) return torch.permute(output, (0, 4, 1, 2, 3)) @@ -524,17 +513,28 @@ def __init__( self.num_fno_layers = num_fno_layers self.fno_width = fno_layer_size self.coord_features = coord_features - # Spectral modes to have weights - if isinstance(num_fno_modes, int): - num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes, num_fno_modes] + self.activation_fn = activation_fn + # Add relative coordinate feature if self.coord_features: self.in_channels = self.in_channels + 4 - self.activation_fn = activation_fn - self.spconv_layers = nn.ModuleList() - self.conv_layers = nn.ModuleList() + # Padding values for spectral conv + if isinstance(padding, int): + padding = [padding, padding, padding, padding] + padding = padding + [0, 0, 0, 0] # Pad with zeros for smaller lists + self.pad = padding[:4] + self.ipad = [-pad if pad > 0 else None for pad in self.pad] + self.padding_type = padding_type + if isinstance(num_fno_modes, int): + num_fno_modes = [num_fno_modes, num_fno_modes, num_fno_modes, num_fno_modes] + + # build lift + self.build_lift_network() + self.build_fno(num_fno_modes) + + def build_lift_network(self) -> None: # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -545,7 +545,10 @@ def __init__( layers.ConvNdFCLayer(int(self.fno_width / 2), self.fno_width) ) + def build_fno(self, num_fno_modes: int) -> None: # Build Neural Fourier Operators + self.spconv_layers = nn.ModuleList() + self.conv_layers = nn.ModuleList() for _ in range(self.num_fno_layers): self.spconv_layers.append( layers.SpectralConv4d( @@ -561,14 +564,6 @@ def __init__( layers.ConvNdKernel1Layer(self.fno_width, self.fno_width) ) - # Padding values for spectral conv - if isinstance(padding, int): - padding = [padding, padding, padding, padding] - padding = padding + [0, 0, 0, 0] # Pad with zeros for smaller lists - self.pad = padding[:4] - self.ipad = [-pad if pad > 0 else None for pad in self.pad] - self.padding_type = padding_type - def forward(self, x: Tensor) -> Tensor: if self.coord_features: coord_feat = self.meshgrid(list(x.shape), x.device) @@ -632,7 +627,6 @@ def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: output = torch.permute(value, (0, 2, 3, 4, 5, 1)) return output.reshape(-1, output.size(-1)), y_shape - def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: output = value.reshape( shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1) @@ -640,12 +634,6 @@ def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: return torch.permute(output, (0, 5, 1, 2, 3, 4)) -# Functions for converting between point based and grid (image) representations - - - - - # =================================================================== # =================================================================== # General FNO Model @@ -755,6 +743,7 @@ def __init__( self.padding_type = padding_type self.activation_fn = layers.get_activation(activation_fn) self.coord_features = coord_features + self.dimension = dimension # decoder net self.decoder_net = FullyConnected( @@ -765,12 +754,7 @@ def __init__( activation_fn=decoder_activation_fn, ) - if dimension == 1: - FNOModel = self.get_FNO1DEncoder() - else: - raise NotImplementedError( - "Invalid dimensionality. Only 1D, 2D, 3D and 4D FNO implemented" - ) + FNOModel = self.getFNOEncoder() self.spec_encoder = FNOModel( in_channels, @@ -783,8 +767,19 @@ def __init__( coord_features=self.coord_features, ) - def get_FNO1DEncoder(self): # put all dims in here - return FNO1DEncoder + def getFNOEncoder(self): + if self.dimension == 1: + return FNO1DEncoder + elif self.dimension == 2: + return FNO2DEncoder + elif self.dimension == 3: + return FNO3DEncoder + elif self.dimension == 4: + return FNO4DEncoder + else: + raise NotImplementedError( + "Invalid dimensionality. Only 1D, 2D, 3D and 4D FNO implemented" + ) def forward(self, x: Tensor) -> Tensor: # Fourier encoder From 38686d92651d3d87c9abf4dea15d2adbd02f8fc9 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Wed, 6 Dec 2023 12:51:16 +0100 Subject: [PATCH 3/7] linting --- modulus/models/fno/fno.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modulus/models/fno/fno.py b/modulus/models/fno/fno.py index dda8a20eeb..d858bf2d55 100644 --- a/modulus/models/fno/fno.py +++ b/modulus/models/fno/fno.py @@ -27,7 +27,6 @@ from ..mlp import FullyConnected from ..module import Module - # =================================================================== # =================================================================== # 1D FNO From 10514098f642ee9db1a4e6bc9e4cddcdf7a9c737 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Tue, 12 Dec 2023 14:38:06 +0100 Subject: [PATCH 4/7] docstrings for grid_to_point and point_to_grid --- modulus/models/fno/fno.py | 100 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/modulus/models/fno/fno.py b/modulus/models/fno/fno.py index d858bf2d55..82f68325dd 100644 --- a/modulus/models/fno/fno.py +++ b/modulus/models/fno/fno.py @@ -154,11 +154,36 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: return grid_x def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + """converting from grid based (image) to point based representation + + Parameters + ---------- + value : Meshgrid tensor + + Returns + ------- + Tuple + Tensor, meshgrid shape + """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 1)) return output.reshape(-1, output.size(-1)), y_shape def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + """converting from point based to grid based (image) representation + + Parameters + ---------- + value : Tensor + Tensor + shape : List[int] + meshgrid shape + + Returns + ------- + Tensor + Meshgrid tensor + """ output = value.reshape(shape[0], shape[2], value.size(-1)) return torch.permute(output, (0, 2, 1)) @@ -303,11 +328,36 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: return torch.cat((grid_x, grid_y), dim=1) def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + """converting from grid based (image) to point based representation + + Parameters + ---------- + value : Meshgrid tensor + + Returns + ------- + Tuple + Tensor, meshgrid shape + """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 1)) return output.reshape(-1, output.size(-1)), y_shape def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + """converting from point based to grid based (image) representation + + Parameters + ---------- + value : Tensor + Tensor + shape : List[int] + meshgrid shape + + Returns + ------- + Tensor + Meshgrid tensor + """ output = value.reshape(shape[0], shape[2], shape[3], value.size(-1)) return torch.permute(output, (0, 3, 1, 2)) @@ -456,11 +506,36 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: return torch.cat((grid_x, grid_y, grid_z), dim=1) def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + """converting from grid based (image) to point based representation + + Parameters + ---------- + value : Meshgrid tensor + + Returns + ------- + Tuple + Tensor, meshgrid shape + """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 4, 1)) return output.reshape(-1, output.size(-1)), y_shape def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + """converting from point based to grid based (image) representation + + Parameters + ---------- + value : Tensor + Tensor + shape : List[int] + meshgrid shape + + Returns + ------- + Tensor + Meshgrid tensor + """ output = value.reshape(shape[0], shape[2], shape[3], shape[4], value.size(-1)) return torch.permute(output, (0, 4, 1, 2, 3)) @@ -622,11 +697,36 @@ def meshgrid(self, shape: List[int], device: torch.device) -> Tensor: return torch.cat((grid_x, grid_y, grid_z, grid_t), dim=1) def grid_to_points(self, value: Tensor) -> Tuple[Tensor, List[int]]: + """converting from grid based (image) to point based representation + + Parameters + ---------- + value : Meshgrid tensor + + Returns + ------- + Tuple + Tensor, meshgrid shape + """ y_shape = list(value.size()) output = torch.permute(value, (0, 2, 3, 4, 5, 1)) return output.reshape(-1, output.size(-1)), y_shape def points_to_grid(self, value: Tensor, shape: List[int]) -> Tensor: + """converting from point based to grid based (image) representation + + Parameters + ---------- + value : Tensor + Tensor + shape : List[int] + meshgrid shape + + Returns + ------- + Tensor + Meshgrid tensor + """ output = value.reshape( shape[0], shape[2], shape[3], shape[4], shape[5], value.size(-1) ) From faac3df45b6f938427f1f2fc059976d42d3b1b01 Mon Sep 17 00:00:00 2001 From: MarsuPila <22983240+MarsuPila@users.noreply.github.com> Date: Tue, 12 Dec 2023 18:44:18 +0100 Subject: [PATCH 5/7] doc strings for all internal FNO subroutines --- modulus/models/fno/fno.py | 44 +++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/modulus/models/fno/fno.py b/modulus/models/fno/fno.py index 82f68325dd..9981d21d60 100644 --- a/modulus/models/fno/fno.py +++ b/modulus/models/fno/fno.py @@ -87,11 +87,15 @@ def __init__( self.ipad = [-pad if pad > 0 else None for pad in self.pad] self.padding_type = padding_type + if isinstance(num_fno_modes, int): + num_fno_modes = [num_fno_modes] + # build lift self.build_lift_network() self.build_fno(num_fno_modes) def build_lift_network(self) -> None: + """construct network for lifting variables to latent space.""" self.lift_network = torch.nn.Sequential() self.lift_network.append( layers.Conv1dFCLayer(self.in_channels, int(self.fno_width / 2)) @@ -101,10 +105,14 @@ def build_lift_network(self) -> None: layers.Conv1dFCLayer(int(self.fno_width / 2), self.fno_width) ) - def build_fno(self, num_fno_modes: int) -> None: - if isinstance(num_fno_modes, int): - num_fno_modes = [num_fno_modes] + def build_fno(self, num_fno_modes: List[int]) -> None: + """construct FNO block. + Parameters + ---------- + num_fno_modes : List[int] + Number of Fourier modes kept in spectral convolutions + """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() @@ -256,6 +264,7 @@ def __init__( self.build_fno(num_fno_modes) def build_lift_network(self) -> None: + """construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -266,7 +275,14 @@ def build_lift_network(self) -> None: layers.Conv2dFCLayer(int(self.fno_width / 2), self.fno_width) ) - def build_fno(self, num_fno_modes: int) -> None: + def build_fno(self, num_fno_modes: List[int]) -> None: + """construct FNO block. + Parameters + ---------- + num_fno_modes : List[int] + Number of Fourier modes kept in spectral convolutions + + """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() @@ -431,6 +447,7 @@ def __init__( self.build_fno(num_fno_modes) def build_lift_network(self) -> None: + """construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -441,7 +458,14 @@ def build_lift_network(self) -> None: layers.Conv3dFCLayer(int(self.fno_width / 2), self.fno_width) ) - def build_fno(self, num_fno_modes: int) -> None: + def build_fno(self, num_fno_modes: List[int]) -> None: + """construct FNO block. + Parameters + ---------- + num_fno_modes : List[int] + Number of Fourier modes kept in spectral convolutions + + """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() @@ -609,6 +633,7 @@ def __init__( self.build_fno(num_fno_modes) def build_lift_network(self) -> None: + """construct network for lifting variables to latent space.""" # Initial lift network self.lift_network = torch.nn.Sequential() self.lift_network.append( @@ -619,7 +644,14 @@ def build_lift_network(self) -> None: layers.ConvNdFCLayer(int(self.fno_width / 2), self.fno_width) ) - def build_fno(self, num_fno_modes: int) -> None: + def build_fno(self, num_fno_modes: List[int]) -> None: + """construct FNO block. + Parameters + ---------- + num_fno_modes : List[int] + Number of Fourier modes kept in spectral convolutions + + """ # Build Neural Fourier Operators self.spconv_layers = nn.ModuleList() self.conv_layers = nn.ModuleList() From e25b05daa5ee9da8639cc6af3a1e07fe9c819b65 Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:47:14 -0800 Subject: [PATCH 6/7] Update CHANGELOG.md --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e02ade097..a837328e3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,10 +17,12 @@ process group config. - Updated Frechet Inception Distance to use Wasserstein 2-norm with improved stability. - Molecular Dynamics example. -- Updating MLFLow logging such that only proc 0 logs to MLFlow. ### Changed +- MLFLow logging such that only proc 0 logs to MLFlow. +- FNO given seperate methods for constructing lift and FNO encoder layers. + ### Deprecated ### Removed From a91d16b3fc3a4069366cb94692c05d734432948b Mon Sep 17 00:00:00 2001 From: Nicholas Geneva <5533524+NickGeneva@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:47:31 -0800 Subject: [PATCH 7/7] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a837328e3c..d2518b2fa2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ stability. ### Changed - MLFLow logging such that only proc 0 logs to MLFlow. -- FNO given seperate methods for constructing lift and FNO encoder layers. +- FNO given seperate methods for constructing lift and spectral encoder layers. ### Deprecated