diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4de59323..8f820a8d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: python-check-blanket-noqa # Check for # noqa: all - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black args: [--line-length=120] @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.1 hooks: - id: ruff args: @@ -59,17 +59,12 @@ repos: hooks: - id: rstfmt exclude: 'cli/.*' # Because we use argparse -- repo: https://github.com/b8raoult/pre-commit-docconvert - rev: "0.1.5" - hooks: - - id: docconvert - args: ["numpy"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.4" + rev: "v2.5.0" hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig - rev: v0.64.0 + rev: v0.65.0 hooks: - id: docsig args: @@ -79,6 +74,5 @@ repos: - --check-protected # Check protected methods - --check-class # Check class docstrings - --disable=E113 # Disable empty docstrings - - --summary # Print a summary ci: autoupdate_schedule: monthly diff --git a/CHANGELOG.md b/CHANGELOG.md index f68f20d4..ac7eae89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.4.0...HEAD) -- Add synchronisation workflow +### Added + +- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37) +- Mask NaN values in training loss function [#56](https://github.com/ecmwf/anemoi-models/pull/56) +- Added dynamic NaN masking for the imputer class with two new classes: DynamicInputImputer, DynamicConstantImputer [#89](https://github.com/ecmwf/anemoi-models/pull/89) +- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84) +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) + +## [0.4.0](https://github.com/ecmwf/anemoi-models/compare/0.3.0...0.4.0) - Improvements to Model Design ### Added +- Add synchronisation workflow [#60](https://github.com/ecmwf/anemoi-models/pull/60) - Add anemoi-transform link to documentation - Codeowners file - Pygrep precommit hooks @@ -22,7 +31,10 @@ Keep it human-readable, your future self will thank you! - configurabilty of the dropout probability in the the MultiHeadSelfAttention module - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) +- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69) +- Add remappers, e.g. 
link functions to apply during training to facilitate learning of variables with a difficult distribution [#88] +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) ### Changed - Bugfixes for CI @@ -33,6 +45,7 @@ Keep it human-readable, your future self will thank you! - ci: extened python versions to include 3.11 and 3.12 [#66](https://github.com/ecmwf/anemoi-models/pull/66) - Update copyright notice - Fix `__version__` import in init +- Fix missing copyrights [#71](https://github.com/ecmwf/anemoi-models/pull/71) ### Removed diff --git a/docs/conf.py b/docs/conf.py index 8a5d5688..32e7531c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,7 +29,7 @@ project = "Anemoi Models" -author = "ECMWF" +author = "Anemoi contributors" year = datetime.datetime.now().year if year == 2024: @@ -37,7 +37,7 @@ else: years = "2024-%s" % (year,) -copyright = "%s, ECMWF" % (years,) +copyright = "%s, Anemoi contributors" % (years,) try: from anemoi.models._version import __version__ diff --git a/docs/modules/models.rst b/docs/modules/models.rst index 392a9d61..416257df 100644 --- a/docs/modules/models.rst +++ b/docs/modules/models.rst @@ -13,3 +13,29 @@ encoder, processor, and decoder. :members: :no-undoc-members: :show-inheritance: + +********************************************** + Encoder Hierarchical processor Decoder Model +********************************************** + +This model extends the standard encoder-processor-decoder architecture +by introducing a **hierarchical processor**. + +Compared to the AnemoiModelEncProcDec model, this architecture requires +a predefined list of hidden nodes, `[hidden_1, ..., hidden_n]`. These +nodes must be sorted to match the expected flow of information `data -> +hidden_1 -> ... -> hidden_n -> ... -> hidden_1 -> data`. + +A new argument is added to the configuration file: +`enable_hierarchical_level_processing`. This argument determines whether +a processor is added at each hierarchy level or only at the final level. + +By default, the number of channels for the mappers is defined as `2^n * +config.num_channels`, where `n` represents the hierarchy level. This +scaling ensures that the processing capacity grows proportionally with +the depth of the hierarchy, enabling efficient handling of data. + +.. automodule:: anemoi.models.models.hierarchical + :members: + :no-undoc-members: + :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index 214f82c1..6d473472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ - [build-system] build-backend = "setuptools.build_meta" @@ -36,6 +35,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] diff --git a/src/anemoi/models/__init__.py b/src/anemoi/models/__init__.py index bdd630a2..019edd80 100644 --- a/src/anemoi/models/__init__.py +++ b/src/anemoi/models/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/models/__main__.py b/src/anemoi/models/__main__.py index be940c27..0057b940 100644 --- a/src/anemoi/models/__main__.py +++ b/src/anemoi/models/__main__.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser diff --git a/src/anemoi/models/commands/__init__.py b/src/anemoi/models/commands/__init__.py index cebb5395..e5e2219d 100644 --- a/src/anemoi/models/commands/__init__.py +++ b/src/anemoi/models/commands/__init__.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# import os diff --git a/src/anemoi/models/data_indices/__init__.py b/src/anemoi/models/data_indices/__init__.py index e69de29b..c167afa2 100644 --- a/src/anemoi/models/data_indices/__init__.py +++ b/src/anemoi/models/data_indices/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/models/distributed/__init__.py b/src/anemoi/models/distributed/__init__.py index e69de29b..c167afa2 100644 --- a/src/anemoi/models/distributed/__init__.py +++ b/src/anemoi/models/distributed/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index aba62a23..261dec29 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -1,11 +1,11 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# import uuid @@ -37,6 +37,8 @@ class AnemoiModelInterface(torch.nn.Module): Statistics for the data. metadata : dict Metadata for the model. + supporting_arrays : dict + Numpy arraysto store in the checkpoint. data_indices : dict Indices for the data. pre_processors : Processors @@ -48,7 +50,14 @@ class AnemoiModelInterface(torch.nn.Module): """ def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, + *, + config: DotDict, + graph_data: HeteroData, + statistics: dict, + data_indices: dict, + metadata: dict, + supporting_arrays: dict = None, ) -> None: super().__init__() self.config = config @@ -57,6 +66,7 @@ def __init__( self.graph_data = graph_data self.statistics = statistics self.metadata = metadata + self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {} self.data_indices = data_indices self._build_model() diff --git a/src/anemoi/models/layers/__init__.py b/src/anemoi/models/layers/__init__.py index e69de29b..c167afa2 100644 --- a/src/anemoi/models/layers/__init__.py +++ b/src/anemoi/models/layers/__init__.py @@ -0,0 +1,8 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index 60446d6c..72e487d2 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -512,8 +512,9 @@ def forward(
             edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
                 num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
             )
+            out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device)
             for i in range(num_chunks):
-                out1 = self.conv(
+                out += self.conv(
                     query=query,
                     key=key,
                     value=value,
@@ -521,9 +522,6 @@ def forward(
                     edge_index=edge_index_list[i],
                     size=size,
                 )
-                if i == 0:
-                    out = torch.zeros_like(out1, device=out1.device)
-                out = out + out1
         else:
             out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py
index c7dbefca..a934d32b 100644
--- a/src/anemoi/models/layers/graph.py
+++ b/src/anemoi/models/layers/graph.py
@@ -12,6 +12,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torch_geometric.data import HeteroData
 
 
 class TrainableTensor(nn.Module):
@@ -36,8 +37,77 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
     def forward(self, x: Tensor, batch_size: int) -> Tensor:
         latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
         if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
         return torch.cat(
             latent,
             dim=-1,  # feature dimension
         )
+
+
+class NamedNodesAttributes(nn.Module):
+    """Named Nodes Attributes information.
+
+    Attributes
+    ----------
+    num_nodes : dict[str, int]
+        Number of nodes for each group of nodes.
+    attr_ndims : dict[str, int]
+        Total dimension of node attributes (non-trainable + trainable) for each group of nodes.
+    trainable_tensors : nn.ModuleDict
+        Dictionary of trainable tensors for each group of nodes.
+
+    Methods
+    -------
+    get_coordinates(self, name: str) -> Tensor
+        Get the coordinates of a set of nodes.
+    forward(self, name: str, batch_size: int) -> Tensor
+        Get the node attributes to be passed through the graph neural network.
+    """
+
+    num_nodes: dict[str, int]
+    attr_ndims: dict[str, int]
+    trainable_tensors: dict[str, TrainableTensor]
+
+    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
+        """Initialize NamedNodesAttributes."""
+        super().__init__()
+
+        self.define_fixed_attributes(graph_data, num_trainable_params)
+
+        self.trainable_tensors = nn.ModuleDict()
+        for nodes_name, nodes in graph_data.node_items():
+            self.register_coordinates(nodes_name, nodes.x)
+            self.register_tensor(nodes_name, num_trainable_params)
+
+    def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None:
+        """Define fixed attributes."""
+        nodes_names = list(graph_data.node_types)
+        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in nodes_names}
+        self.attr_ndims = {
+            nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in nodes_names
+        }
+
+    def register_coordinates(self, name: str, node_coords: Tensor) -> None:
+        """Register coordinates."""
+        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
+        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
+
+    def get_coordinates(self, name: str) -> Tensor:
+        """Return original coordinates."""
+        sin_cos_coords = getattr(self, f"latlons_{name}")
+        ndim = sin_cos_coords.shape[1] // 2
+        sin_values = sin_cos_coords[:, :ndim]
+        cos_values = sin_cos_coords[:, ndim:]
+        return torch.atan2(sin_values, cos_values)
+
+    def register_tensor(self, name: str, num_trainable_params: int) -> None:
+        """Register a trainable tensor."""
+        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params)
+
+    def forward(self, name: str, batch_size: int) -> Tensor:
+        """Returns the node attributes to be passed through the graph neural network.
+
+        It includes both the coordinates and the trainable parameters.
+        """
+        latlons = getattr(self, f"latlons_{name}")
+        return self.trainable_tensors[name](latlons, batch_size)
diff --git a/src/anemoi/models/layers/processor.py b/src/anemoi/models/layers/processor.py
index 4fd32311..8dba1f66 100644
--- a/src/anemoi/models/layers/processor.py
+++ b/src/anemoi/models/layers/processor.py
@@ -323,6 +323,7 @@ def forward(
         *args,
         **kwargs,
     ) -> Tensor:
+
         shape_nodes = change_channels_in_shape(shard_shapes, self.num_channels)
         edge_attr = self.trainable(self.edge_attr, batch_size)
 
diff --git a/src/anemoi/models/models/__init__.py b/src/anemoi/models/models/__init__.py
index e69de29b..2072f12f 100644
--- a/src/anemoi/models/models/__init__.py
+++ b/src/anemoi/models/models/__init__.py
@@ -0,0 +1,13 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+ +from .encoder_processor_decoder import AnemoiModelEncProcDec +from .hierarchical import AnemoiModelEncProcDecHierarchical + +__all__ = ["AnemoiModelEncProcDec", "AnemoiModelEncProcDecHierarchical"] diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index bdb6260e..c67c8c03 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -22,7 +22,7 @@ from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards -from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.graph import NamedNodesAttributes LOGGER = logging.getLogger(__name__) @@ -56,33 +56,24 @@ def __init__( self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - - self.multi_step = model_config.training.multistep_input - - self._define_tensor_sizes(model_config) - - # Create trainable tensors - self._create_trainable_attributes() - - # Register lat/lon of nodes - self._register_latlon("data", self._graph_name_data) - self._register_latlon("hidden", self._graph_name_hidden) - self.data_indices = data_indices + self.multi_step = model_config.training.multistep_input self.num_channels = model_config.model.num_channels - input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) + + input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] # Encoder data -> hidden self.encoder = instantiate( model_config.model.encoder, in_channels_src=input_dim, - in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden], hidden_dim=self.num_channels, sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], - src_grid_size=self._data_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Processor hidden -> hidden @@ -90,8 +81,8 @@ def __init__( model_config.model.processor, num_channels=self.num_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Decoder hidden -> data @@ -102,8 +93,8 @@ def __init__( hidden_dim=self.num_channels, out_channels_dst=self.num_output_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._data_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -133,34 +124,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotDict) -> None: - 
self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes - self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes - - self.trainable_data_size = config.model.trainable_parameters.data - self.trainable_hidden_size = config.model.trainable_parameters.hidden - - def _register_latlon(self, name: str, nodes: str) -> None: - """Register lat/lon buffers. - - Parameters - ---------- - name : str - Name to store the lat-lon coordinates of the nodes. - nodes : str - Name of nodes to map - """ - coords = self._graph_data[nodes].x - sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - - def _create_trainable_attributes(self) -> None: - """Create all trainable attributes.""" - self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) - self.trainable_hidden = TrainableTensor( - trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size - ) - def _run_mapper( self, mapper: nn.Module, @@ -210,12 +173,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_data_latent = torch.cat( ( einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), - self.trainable_data(self.latlons_data, batch_size=batch_size), + self.node_attributes(self._graph_name_data, batch_size=batch_size), ), dim=-1, # feature dimension ) - x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size) # get shard shapes shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) diff --git a/src/anemoi/models/models/hierarchical.py b/src/anemoi/models/models/hierarchical.py new file mode 100644 index 00000000..94e82581 --- /dev/null +++ b/src/anemoi/models/models/hierarchical.py @@ -0,0 +1,308 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +from typing import Optional + +import einops +import torch +from anemoi.utils.config import DotDict +from hydra.utils import instantiate +from torch import Tensor +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup +from torch_geometric.data import HeteroData + +from anemoi.models.distributed.shapes import get_shape_shards +from anemoi.models.layers.graph import NamedNodesAttributes +from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.models import AnemoiModelEncProcDec + +LOGGER = logging.getLogger(__name__) + + +class AnemoiModelEncProcDecHierarchical(AnemoiModelEncProcDec): + """Message passing hierarchical graph neural network.""" + + def __init__( + self, + *, + model_config: DotDict, + data_indices: dict, + graph_data: HeteroData, + ) -> None: + """Initializes the graph neural network. 
+
+        Parameters
+        ----------
+        model_config : DotDict
+            Job configuration
+        data_indices : dict
+            Data indices
+        graph_data : HeteroData
+            Graph definition
+        """
+        nn.Module.__init__(self)
+
+        self._graph_data = graph_data
+        self._graph_name_data = model_config.graph.data
+        self._graph_hidden_names = model_config.graph.hidden
+        self.num_hidden = len(self._graph_hidden_names)
+
+        # Unpack config for hierarchical graph
+        self.level_process = model_config.model.enable_hierarchical_level_processing
+
+        # hidden_dims is the dimensionality of features at each depth
+        self.hidden_dims = {
+            hidden: model_config.model.num_channels * (2**i) for i, hidden in enumerate(self._graph_hidden_names)
+        }
+
+        self._calculate_shapes_and_indices(data_indices)
+        self._assert_matching_indices(data_indices)
+        self.data_indices = data_indices
+
+        self.multi_step = model_config.training.multistep_input
+
+        # self.node_attributes = {hidden_name: NamedNodesAttributes(model_config.model.trainable_parameters[hidden_name], self._graph_data)
+        #                        for hidden_name in self._graph_hidden_names}
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
+
+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]
+
+        # Encoder data -> hidden
+        self.encoder = instantiate(
+            model_config.model.encoder,
+            in_channels_src=input_dim,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_hidden_names[0]],
+            hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
+            sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_hidden_names[0])],
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]],
+        )
+
+        # Level processors
+        if self.level_process:
+            self.down_level_processor = nn.ModuleDict()
+            self.up_level_processor = nn.ModuleDict()
+
+            for i in range(0, self.num_hidden):
+                nodes_names = self._graph_hidden_names[i]
+
+                self.down_level_processor[nodes_names] = instantiate(
+                    model_config.model.processor,
+                    num_channels=self.hidden_dims[nodes_names],
+                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
+                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    num_layers=model_config.model.level_process_num_layers,
+                )
+
+                self.up_level_processor[nodes_names] = instantiate(
+                    model_config.model.processor,
+                    num_channels=self.hidden_dims[nodes_names],
+                    sub_graph=self._graph_data[(nodes_names, "to", nodes_names)],
+                    src_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    dst_grid_size=self.node_attributes.num_nodes[nodes_names],
+                    num_layers=model_config.model.level_process_num_layers,
+                )
+
+            # delete final upscale (does not exist): |->|->|<-|<-|
+            del self.up_level_processor[nodes_names]
+
+        # Downscale
+        self.downscale = nn.ModuleDict()
+
+        for i in range(0, self.num_hidden - 1):
+            src_nodes_name = self._graph_hidden_names[i]
+            dst_nodes_name = self._graph_hidden_names[i + 1]
+
+            self.downscale[src_nodes_name] = instantiate(
+                model_config.model.encoder,
+                in_channels_src=self.hidden_dims[src_nodes_name],
+                in_channels_dst=self.node_attributes.attr_ndims[dst_nodes_name],
+                hidden_dim=self.hidden_dims[dst_nodes_name],
+                sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
+                src_grid_size=self.node_attributes.num_nodes[src_nodes_name],
+                dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name],
+            )
+
+        # Upscale
+        
self.upscale = nn.ModuleDict() + + for i in range(1, self.num_hidden): + src_nodes_name = self._graph_hidden_names[i] + dst_nodes_name = self._graph_hidden_names[i - 1] + + self.upscale[src_nodes_name] = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[src_nodes_name], + in_channels_dst=self.hidden_dims[dst_nodes_name], + hidden_dim=self.hidden_dims[src_nodes_name], + out_channels_dst=self.hidden_dims[dst_nodes_name], + sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)], + src_grid_size=self.node_attributes.num_nodes[src_nodes_name], + dst_grid_size=self.node_attributes.num_nodes[dst_nodes_name], + ) + + # Decoder hidden -> data + self.decoder = instantiate( + model_config.model.decoder, + in_channels_src=self.hidden_dims[self._graph_hidden_names[0]], + in_channels_dst=input_dim, + hidden_dim=self.hidden_dims[self._graph_hidden_names[0]], + out_channels_dst=self.num_output_channels, + sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)], + src_grid_size=self.node_attributes.num_nodes[self._graph_hidden_names[0]], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + ) + + # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) + self.boundings = nn.ModuleList( + [ + instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index) + for cfg in getattr(model_config.model, "bounding", []) + ] + ) + + def _create_trainable_attributes(self) -> None: + """Create all trainable attributes.""" + self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) + self.trainable_hidden = nn.ModuleDict() + + for hidden in self._graph_hidden_names: + self.trainable_hidden[hidden] = TrainableTensor( + trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden] + ) + + def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor: + batch_size = x.shape[0] + ensemble_size = x.shape[2] + + # add data positional info (lat/lon) + x_trainable_data = torch.cat( + ( + einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), + self.node_attributes(self._graph_name_data, batch_size=batch_size), + ), + dim=-1, # feature dimension + ) + + # Get all trainable parameters for the hidden layers -> initialisation of each hidden, which becomes trainable bias + x_trainable_hiddens = {} + for hidden in self._graph_hidden_names: + x_trainable_hiddens[hidden] = self.node_attributes(hidden, batch_size=batch_size) + + # Get data and hidden shapes for sharding + shard_shapes_data = get_shape_shards(x_trainable_data, 0, model_comm_group) + shard_shapes_hiddens = {} + for hidden, x_latent in x_trainable_hiddens.items(): + shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group) + + # Run encoder + x_data_latent, curr_latent = self._run_mapper( + self.encoder, + (x_trainable_data, x_trainable_hiddens[self._graph_hidden_names[0]]), + batch_size=batch_size, + shard_shapes=(shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]]), + model_comm_group=model_comm_group, + ) + + # Run processor + x_encoded_latents = {} + x_skip = {} + + ## Downscale + for i in range(0, self.num_hidden - 1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i + 1] + + # Processing at same level + if self.level_process: + curr_latent = self.down_level_processor[src_hidden_name]( + 
curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[src_hidden_name], + model_comm_group=model_comm_group, + ) + + # store latents for skip connections + x_skip[src_hidden_name] = curr_latent + + # Encode to next hidden level + x_encoded_latents[src_hidden_name], curr_latent = self._run_mapper( + self.downscale[src_hidden_name], + (curr_latent, x_trainable_hiddens[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Processing hidden-most level + if self.level_process: + curr_latent = self.down_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + ## Upscale + for i in range(self.num_hidden - 1, 0, -1): + src_hidden_name = self._graph_hidden_names[i] + dst_hidden_name = self._graph_hidden_names[i - 1] + + # Process to next level + curr_latent = self._run_mapper( + self.upscale[src_hidden_name], + (curr_latent, x_encoded_latents[dst_hidden_name]), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[src_hidden_name], shard_shapes_hiddens[dst_hidden_name]), + model_comm_group=model_comm_group, + ) + + # Add skip connections + curr_latent = curr_latent + x_skip[dst_hidden_name] + + # Processing at same level + if self.level_process: + curr_latent = self.up_level_processor[dst_hidden_name]( + curr_latent, + batch_size=batch_size, + shard_shapes=shard_shapes_hiddens[dst_hidden_name], + model_comm_group=model_comm_group, + ) + + # Run decoder + x_out = self._run_mapper( + self.decoder, + (curr_latent, x_data_latent), + batch_size=batch_size, + shard_shapes=(shard_shapes_hiddens[self._graph_hidden_names[0]], shard_shapes_data), + model_comm_group=model_comm_group, + ) + + x_out = ( + einops.rearrange( + x_out, + "(batch ensemble grid) vars -> batch ensemble grid vars", + batch=batch_size, + ensemble=ensemble_size, + ) + .to(dtype=x.dtype) + .clone() + ) + + # residual connection (just for the prognostic variables) + x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] + + for bounding in self.boundings: + # bounding performed in the order specified in the config file + x_out = bounding(x_out) + + return x_out diff --git a/src/anemoi/models/preprocessing/__init__.py b/src/anemoi/models/preprocessing/__init__.py index cc2cb4f8..e26505f3 100644 --- a/src/anemoi/models/preprocessing/__init__.py +++ b/src/anemoi/models/preprocessing/__init__.py @@ -1,11 +1,11 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-# import logging from typing import Optional @@ -57,25 +57,32 @@ def __init__( super().__init__() - self.default, self.method_config = self._process_config(config) + self.default, self.remap, self.method_config = self._process_config(config) self.methods = self._invert_key_value_list(self.method_config) self.data_indices = data_indices - def _process_config(self, config): + @classmethod + def _process_config(cls, config): _special_keys = ["default", "remap"] # Keys that do not contain a list of variables in a preprocessing method. default = config.get("default", "none") - self.remap = config.get("remap", {}) + remap = config.get("remap", {}) method_config = {k: v for k, v in config.items() if k not in _special_keys and v is not None and v != "none"} if not method_config: LOGGER.warning( - f"{self.__class__.__name__}: Using default method {default} for all variables not specified in the config.", + f"{cls.__name__}: Using default method {default} for all variables not specified in the config.", ) + for m in method_config: + if isinstance(method_config[m], str): + method_config[m] = {method_config[m]: f"{m}_{method_config[m]}"} + elif isinstance(method_config[m], list): + method_config[m] = {method: f"{m}_{method}" for method in method_config[m]} - return default, method_config + return default, remap, method_config - def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[str, str]: + @staticmethod + def _invert_key_value_list(method_config: dict[str, list[str]]) -> dict[str, str]: """Invert a dictionary of methods with lists of variables. Parameters diff --git a/src/anemoi/models/preprocessing/imputer.py b/src/anemoi/models/preprocessing/imputer.py index 2835ef49..0ab57b3e 100644 --- a/src/anemoi/models/preprocessing/imputer.py +++ b/src/anemoi/models/preprocessing/imputer.py @@ -9,6 +9,7 @@ import logging +import warnings from abc import ABC from typing import Optional @@ -43,6 +44,8 @@ def __init__( super().__init__(config, data_indices, statistics) self.nan_locations = None + # weight imputed values wiht zero in loss calculation + self.loss_mask_training = None def _validate_indices(self): assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), ( @@ -104,6 +107,12 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor: """Expand the subset of the mask to the correct shape.""" return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1) + def get_nans(self, x: torch.Tensor) -> torch.Tensor: + """get NaN mask from data""" + # The mask is only saved for the last two dimensions (grid, variable) + idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)] + return torch.isnan(x[idx].squeeze()) + def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: """Impute missing values in the input tensor.""" if not in_place: @@ -115,9 +124,18 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: # Initialise mask if not cached. 
if self.nan_locations is None:
-            # The mask is only saved for the last two dimensions (grid, variable)
-            idx = [slice(0, 1)] * (x.ndim - 2) + [slice(None), slice(None)]
-            self.nan_locations = torch.isnan(x[idx].squeeze())
+
+            # Get NaN locations
+            self.nan_locations = self.get_nans(x)
+
+            # Initialize training loss mask to weigh imputed values with zeroes once
+            self.loss_mask_training = torch.ones(
+                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
+            )  # shape (grid, n_outputs)
+            # for all variables that are imputed and part of the model output, set the loss weight to zero
+            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
+                if idx_dst is not None:
+                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()
 
         # Choose correct index based on number of variables
         if x.shape[-1] == self.num_training_input_vars:
@@ -215,3 +233,77 @@ def __init__(
         self._create_imputation_indices()
 
         self._validate_indices()
+
+
+class DynamicMixin:
+    """Mixin to add dynamic imputation behavior."""
+
+    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
+        """Override to calculate NaN locations dynamically."""
+        return torch.isnan(x)
+
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Impute missing values in the input tensor."""
+        if not in_place:
+            x = x.clone()
+
+        # Initialize mask every time
+        nan_locations = self.get_nans(x)
+
+        self.loss_mask_training = torch.ones(
+            (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
+        )
+
+        # Choose correct index based on number of variables
+        if x.shape[-1] == self.num_training_input_vars:
+            index = self.index_training_input
+        elif x.shape[-1] == self.num_inference_input_vars:
+            index = self.index_inference_input
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+
+        # Replace values
+        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
+            if idx_dst is not None:
+                x[..., idx_dst][nan_locations[..., idx_src]] = value
+
+        return x
+
+    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Impute missing values in the input tensor."""
+        return x
+
+
+class DynamicInputImputer(DynamicMixin, InputImputer):
+    "Imputes missing values using the statistics supplied and a dynamic NaN map."
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        super().__init__(config, data_indices, statistics)
+        warnings.warn(
+            "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
+            The model will be trained to predict imputed values. This might degrade performance."
+        )
+
+
+class DynamicConstantImputer(DynamicMixin, ConstantImputer):
+    "Imputes missing values using the constant value and a dynamic NaN map."
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        super().__init__(config, data_indices, statistics)
+        warnings.warn(
+            "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
+            The model will be trained to predict imputed values. This might degrade performance."
+ ) diff --git a/src/anemoi/models/preprocessing/mappings.py b/src/anemoi/models/preprocessing/mappings.py new file mode 100644 index 00000000..dab46734 --- /dev/null +++ b/src/anemoi/models/preprocessing/mappings.py @@ -0,0 +1,75 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import torch + + +def noop(x): + """No operation.""" + return x + + +def cos_converter(x): + """Convert angle in degree to cos.""" + return torch.cos(x / 180 * torch.pi) + + +def sin_converter(x): + """Convert angle in degree to sin.""" + return torch.sin(x / 180 * torch.pi) + + +def atan2_converter(x): + """Convert cos and sin to angle in degree. + + Input: + x[..., 0]: cos + x[..., 1]: sin + """ + return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360) + + +def log1p_converter(x): + """Convert positive var in to log(1+var).""" + return torch.log1p(x) + + +def boxcox_converter(x, lambd=0.5): + """Convert positive var in to boxcox(var).""" + pos_lam = (torch.pow(x, lambd) - 1) / lambd + null_lam = torch.log(x) + if lambd == 0: + return null_lam + else: + return pos_lam + + +def sqrt_converter(x): + """Convert positive var in to sqrt(var).""" + return torch.sqrt(x) + + +def expm1_converter(x): + """Convert back log(1+var) to var.""" + return torch.expm1(x) + + +def square_converter(x): + """Convert back sqrt(var) to var.""" + return x**2 + + +def inverse_boxcox_converter(x, lambd=0.5): + """Convert back boxcox(var) to var.""" + pos_lam = torch.pow(x * lambd + 1, 1 / lambd) + null_lam = torch.exp(x) + if lambd == 0: + return null_lam + else: + return pos_lam diff --git a/src/anemoi/models/preprocessing/monomapper.py b/src/anemoi/models/preprocessing/monomapper.py new file mode 100644 index 00000000..0359a4c3 --- /dev/null +++ b/src/anemoi/models/preprocessing/monomapper.py @@ -0,0 +1,150 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ + +import logging +from abc import ABC +from typing import Optional + +import torch + +from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing import BasePreprocessor +from anemoi.models.preprocessing.mappings import boxcox_converter +from anemoi.models.preprocessing.mappings import expm1_converter +from anemoi.models.preprocessing.mappings import inverse_boxcox_converter +from anemoi.models.preprocessing.mappings import log1p_converter +from anemoi.models.preprocessing.mappings import noop +from anemoi.models.preprocessing.mappings import sqrt_converter +from anemoi.models.preprocessing.mappings import square_converter + +LOGGER = logging.getLogger(__name__) + + +class Monomapper(BasePreprocessor, ABC): + """Remap and convert variables for single variables.""" + + supported_methods = { + method: [f, inv] + for method, f, inv in zip( + ["log1p", "sqrt", "boxcox", "none"], + [log1p_converter, sqrt_converter, boxcox_converter, noop], + [expm1_converter, square_converter, inverse_boxcox_converter, noop], + ) + } + + def __init__( + self, + config=None, + data_indices: Optional[IndexCollection] = None, + statistics: Optional[dict] = None, + ) -> None: + super().__init__(config, data_indices, statistics) + self._create_remapping_indices(statistics) + self._validate_indices() + + def _validate_indices(self): + assert ( + len(self.index_training_input) + == len(self.index_inference_input) + == len(self.index_inference_output) + == len(self.index_training_out) + == len(self.remappers) + ), ( + f"Error creating conversion indices {len(self.index_training_input)}, " + f"{len(self.index_inference_input)}, {len(self.index_training_input)}, {len(self.index_training_out)}, {len(self.remappers)}" + ) + + def _create_remapping_indices( + self, + statistics=None, + ): + """Create the parameter indices for remapping.""" + # list for training and inference mode as position of parameters can change + name_to_index_training_input = self.data_indices.data.input.name_to_index + name_to_index_inference_input = self.data_indices.model.input.name_to_index + name_to_index_training_output = self.data_indices.data.output.name_to_index + name_to_index_inference_output = self.data_indices.model.output.name_to_index + self.num_training_input_vars = len(name_to_index_training_input) + self.num_inference_input_vars = len(name_to_index_inference_input) + self.num_training_output_vars = len(name_to_index_training_output) + self.num_inference_output_vars = len(name_to_index_inference_output) + + ( + self.remappers, + self.backmappers, + self.index_training_input, + self.index_training_out, + self.index_inference_input, + self.index_inference_output, + ) = ( + [], + [], + [], + [], + [], + [], + ) + + # Create parameter indices for remapping variables + for name in name_to_index_training_input: + method = self.methods.get(name, self.default) + if method in self.supported_methods: + self.remappers.append(self.supported_methods[method][0]) + self.backmappers.append(self.supported_methods[method][1]) + self.index_training_input.append(name_to_index_training_input[name]) + if name in name_to_index_training_output: + self.index_training_out.append(name_to_index_training_output[name]) + else: + self.index_training_out.append(None) + if name in name_to_index_inference_input: + self.index_inference_input.append(name_to_index_inference_input[name]) + else: + self.index_inference_input.append(None) + if name in name_to_index_inference_output: + 
self.index_inference_output.append(name_to_index_inference_output[name])
+                else:
+                    # this is a forcing variable. It is not in the inference output.
+                    self.index_inference_output.append(None)
+            else:
+                raise KeyError(f"Unknown remapping method for {name}: {method}")
+
+    def transform(self, x, in_place: bool = True) -> torch.Tensor:
+        if not in_place:
+            x = x.clone()
+        if x.shape[-1] == self.num_training_input_vars:
+            idx = self.index_training_input
+        elif x.shape[-1] == self.num_inference_input_vars:
+            idx = self.index_inference_input
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+        for i, remapper in zip(idx, self.remappers):
+            if i is not None:
+                x[..., i] = remapper(x[..., i])
+        return x
+
+    def inverse_transform(self, x, in_place: bool = True) -> torch.Tensor:
+        if not in_place:
+            x = x.clone()
+        if x.shape[-1] == self.num_training_output_vars:
+            idx = self.index_training_out
+        elif x.shape[-1] == self.num_inference_output_vars:
+            idx = self.index_inference_output
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_output_vars}) or inference shape ({self.num_inference_output_vars})",
+            )
+        for i, backmapper in zip(idx, self.backmappers):
+            if i is not None:
+                x[..., i] = backmapper(x[..., i])
+        return x
diff --git a/src/anemoi/models/preprocessing/multimapper.py b/src/anemoi/models/preprocessing/multimapper.py
new file mode 100644
index 00000000..f7772e48
--- /dev/null
+++ b/src/anemoi/models/preprocessing/multimapper.py
@@ -0,0 +1,306 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+from abc import ABC
+from typing import Optional
+
+import torch
+
+from anemoi.models.data_indices.collection import IndexCollection
+from anemoi.models.preprocessing import BasePreprocessor
+from anemoi.models.preprocessing.mappings import atan2_converter
+from anemoi.models.preprocessing.mappings import cos_converter
+from anemoi.models.preprocessing.mappings import sin_converter
+
+LOGGER = logging.getLogger(__name__)
+
+
+class Multimapper(BasePreprocessor, ABC):
+    """Remap single variable to 2 or more variables, or the other way around.
+
+    cos_sin:
+        Remap the variable to cosine and sine.
+        Map output back to degrees.
+
+    ```
+    cos_sin:
+      "mwd" : ["cos_mwd", "sin_mwd"]
+    ```
+    """
+
+    supported_methods = {
+        method: [f, inv]
+        for method, f, inv in zip(
+            ["cos_sin"],
+            [[cos_converter, sin_converter]],
+            [atan2_converter],
+        )
+    }
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        """Initialize the remapper.
+ + Parameters + ---------- + config : DotDict + configuration object of the processor + data_indices : IndexCollection + Data indices for input and output variables + statistics : dict + Data statistics dictionary + """ + super().__init__(config, data_indices, statistics) + self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False + self._create_remapping_indices(statistics) + self._validate_indices() + + def _validate_indices(self): + assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), ( + f"Error creating conversion indices {len(self.index_training_input)}, " + f"{len(self.index_inference_input)}, {len(self.remappers)}" + ) + assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), ( + f"Error creating conversion indices {len(self.index_training_output)}, " + f"{len(self.index_inference_output)}, {len(self.remappers)}" + ) + assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, ( + "Error creating conversion indices: variables remapped in config.data.remapped " + "that have no remapping function defined. Preprocessed tensors contains empty columns." + ) + + def _create_remapping_indices( + self, + statistics=None, + ): + """Create the parameter indices for remapping.""" + # list for training and inference mode as position of parameters can change + name_to_index_training_input = self.data_indices.data.input.name_to_index + name_to_index_inference_input = self.data_indices.model.input.name_to_index + name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index + name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index + name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index + name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index + name_to_index_training_output = self.data_indices.data.output.name_to_index + name_to_index_inference_output = self.data_indices.model.output.name_to_index + + self.num_training_input_vars = len(name_to_index_training_input) + self.num_inference_input_vars = len(name_to_index_inference_input) + self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input) + self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input) + self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output) + self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output) + self.num_training_output_vars = len(name_to_index_training_output) + self.num_inference_output_vars = len(name_to_index_inference_output) + self.indices_keep_training_input = [] + for key, item in self.data_indices.data.input.name_to_index.items(): + if key in self.data_indices.internal_data.input.name_to_index: + self.indices_keep_training_input.append(item) + self.indices_keep_inference_input = [] + for key, item in self.data_indices.model.input.name_to_index.items(): + if key in self.data_indices.internal_model.input.name_to_index: + self.indices_keep_inference_input.append(item) + self.indices_keep_training_output = [] + for key, item in self.data_indices.data.output.name_to_index.items(): + if key in self.data_indices.internal_data.output.name_to_index: + self.indices_keep_training_output.append(item) + self.indices_keep_inference_output = [] + for key, item in 
self.data_indices.model.output.name_to_index.items():
+            if key in self.data_indices.internal_model.output.name_to_index:
+                self.indices_keep_inference_output.append(item)
+
+        (
+            self.index_training_input,
+            self.index_training_remapped_input,
+            self.index_inference_input,
+            self.index_inference_remapped_input,
+            self.index_training_output,
+            self.index_training_backmapped_output,
+            self.index_inference_output,
+            self.index_inference_backmapped_output,
+            self.remappers,
+            self.backmappers,
+        ) = ([], [], [], [], [], [], [], [], [], [])
+
+        # Create parameter indices for remapping variables
+        for name in name_to_index_training_input:
+
+            method = self.methods.get(name, self.default)
+
+            if method == "none":
+                continue
+
+            if method == "cos_sin":
+                self.index_training_input.append(name_to_index_training_input[name])
+                self.index_training_output.append(name_to_index_training_output[name])
+                self.index_inference_input.append(name_to_index_inference_input[name])
+                if name in name_to_index_inference_output:
+                    self.index_inference_output.append(name_to_index_inference_output[name])
+                else:
+                    # this is a forcing variable. It is not in the inference output.
+                    self.index_inference_output.append(None)
+                multiple_training_output, multiple_inference_output = [], []
+                multiple_training_input, multiple_inference_input = [], []
+                for name_dst in self.method_config[method][name]:
+                    assert name_dst in self.data_indices.internal_data.input.name_to_index, (
+                        f"Trying to remap {name} to {name_dst}, but {name_dst} is not a variable. "
+                        f"Remap {name} to {name_dst} in config.data.remapped. "
+                    )
+                    multiple_training_input.append(name_to_index_training_remapped_input[name_dst])
+                    multiple_training_output.append(name_to_index_training_remapped_output[name_dst])
+                    multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst])
+                    if name_dst in name_to_index_inference_remapped_output:
+                        multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst])
+                    else:
+                        # this is a forcing variable. It is not in the inference output.
+                        multiple_inference_output.append(None)
+
+                self.index_training_remapped_input.append(multiple_training_input)
+                self.index_inference_remapped_input.append(multiple_inference_input)
+                self.index_training_backmapped_output.append(multiple_training_output)
+                self.index_inference_backmapped_output.append(multiple_inference_output)
+
+                self.remappers.append([cos_converter, sin_converter])
+                self.backmappers.append(atan2_converter)
+
+                LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.")
+
+            else:
+                raise ValueError(f"Unknown remapping method for {name}: {method}")
+
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Remap and convert the input tensor.
+
+        ```
+        x : torch.Tensor
+            Input tensor
+        in_place : bool
+            Whether to process the tensor in place.
+            in_place is not possible for this preprocessor.
+ ``` + """ + # Choose correct index based on number of variables + if x.shape[-1] == self.num_training_input_vars: + index = self.index_training_input + indices_remapped = self.index_training_remapped_input + indices_keep = self.indices_keep_training_input + target_number_columns = self.num_remapped_training_input_vars + + elif x.shape[-1] == self.num_inference_input_vars: + index = self.index_inference_input + indices_remapped = self.index_inference_remapped_input + indices_keep = self.indices_keep_inference_input + target_number_columns = self.num_remapped_inference_input_vars + + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", + ) + + # create new tensor with target number of columns + x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_preprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_preprocessor_warning = True + + # copy variables that are not remapped + x_remapped[..., : len(indices_keep)] = x[..., indices_keep] + + # Remap variables + for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index): + if idx_src is not None: + for jj, ii in enumerate(idx_dst): + x_remapped[..., ii] = remapper[jj](x[..., idx_src]) + + return x_remapped + + def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: + """Convert and remap the output tensor. + + ``` + x : torch.Tensor + Input tensor + in_place : bool + Whether to process the tensor in place. + in_place is not possible for this postprocessor. + ``` + """ + # Choose correct index based on number of variables + if x.shape[-1] == self.num_remapped_training_output_vars: + index = self.index_training_output + indices_remapped = self.index_training_backmapped_output + indices_keep = self.indices_keep_training_output + target_number_columns = self.num_training_output_vars + + elif x.shape[-1] == self.num_remapped_inference_output_vars: + index = self.index_inference_output + indices_remapped = self.index_inference_backmapped_output + indices_keep = self.indices_keep_inference_output + target_number_columns = self.num_inference_output_vars + + else: + raise ValueError( + f"Input tensor ({x.shape[-1]}) does not match the training " + f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", + ) + + # create new tensor with target number of columns + x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) + if in_place and not self.printed_postprocessor_warning: + LOGGER.warning( + "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", + ) + self.printed_postprocessor_warning = True + + # copy variables that are not remapped + x_remapped[..., indices_keep] = x[..., : len(indices_keep)] + + # Backmap variables + for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): + if idx_dst is not None: + x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) + + return x_remapped + + def transform_loss_mask(self, mask: torch.Tensor) -> torch.Tensor: + """Remap the loss mask. 
+ + ``` + x : torch.Tensor + Loss mask + ``` + """ + # use indices at model output level + index = self.index_inference_backmapped_output + indices_remapped = self.index_inference_output + indices_keep = self.indices_keep_inference_output + + # create new loss mask with target number of columns + mask_remapped = torch.zeros( + mask.shape[:-1] + (mask.shape[-1] + len(indices_remapped),), dtype=mask.dtype, device=mask.device + ) + + # copy loss mask for variables that are not remapped + mask_remapped[..., : len(indices_keep)] = mask[..., indices_keep] + + # remap loss mask for rest of variables + for idx_src, idx_dst in zip(indices_remapped, index): + if idx_dst is not None: + for ii in idx_dst: + mask_remapped[..., ii] = mask[..., idx_src] + + return mask_remapped diff --git a/src/anemoi/models/preprocessing/remapper.py b/src/anemoi/models/preprocessing/remapper.py index cc888222..c3f39c2e 100644 --- a/src/anemoi/models/preprocessing/remapper.py +++ b/src/anemoi/models/preprocessing/remapper.py @@ -12,290 +12,36 @@ from abc import ABC from typing import Optional -import torch - from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.preprocessing import BasePreprocessor +from anemoi.models.preprocessing.monomapper import Monomapper +from anemoi.models.preprocessing.multimapper import Multimapper LOGGER = logging.getLogger(__name__) -def cos_converter(x): - """Convert angle in degree to cos.""" - return torch.cos(x / 180 * torch.pi) - - -def sin_converter(x): - """Convert angle in degree to sin.""" - return torch.sin(x / 180 * torch.pi) - - -def atan2_converter(x): - """Convert cos and sin to angle in degree. - - Input: - x[..., 0]: cos - x[..., 1]: sin - """ - return torch.remainder(torch.atan2(x[..., 1], x[..., 0]) * 180 / torch.pi, 360) - +class Remapper(BasePreprocessor, ABC): + """Remap and convert variables for single variables.""" -class BaseRemapperVariable(BasePreprocessor, ABC): - """Base class for Remapping Variables.""" - - def __init__( - self, + def __new__( + cls, config=None, data_indices: Optional[IndexCollection] = None, statistics: Optional[dict] = None, ) -> None: - """Initialize the remapper. - - Parameters - ---------- - config : DotDict - configuration object of the processor - data_indices : IndexCollection - Data indices for input and output variables - statistics : dict - Data statistics dictionary - """ - super().__init__(config, data_indices, statistics) - - def _validate_indices(self): - assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.remappers), ( - f"Error creating conversion indices {len(self.index_training_input)}, " - f"{len(self.index_inference_input)}, {len(self.remappers)}" - ) - assert len(self.index_training_output) == len(self.index_inference_output) <= len(self.remappers), ( - f"Error creating conversion indices {len(self.index_training_output)}, " - f"{len(self.index_inference_output)}, {len(self.remappers)}" - ) - assert len(set(self.index_training_input + self.indices_keep_training_input)) == self.num_training_input_vars, ( - "Error creating conversion indices: variables remapped in config.data.remapped " - "that have no remapping function defined. Preprocessed tensors contains empty columns." 
- ) - - def _create_remapping_indices( - self, - statistics=None, - ): - """Create the parameter indices for remapping.""" - # list for training and inference mode as position of parameters can change - name_to_index_training_input = self.data_indices.data.input.name_to_index - name_to_index_inference_input = self.data_indices.model.input.name_to_index - name_to_index_training_remapped_input = self.data_indices.internal_data.input.name_to_index - name_to_index_inference_remapped_input = self.data_indices.internal_model.input.name_to_index - name_to_index_training_remapped_output = self.data_indices.internal_data.output.name_to_index - name_to_index_inference_remapped_output = self.data_indices.internal_model.output.name_to_index - name_to_index_training_output = self.data_indices.data.output.name_to_index - name_to_index_inference_output = self.data_indices.model.output.name_to_index - - self.num_training_input_vars = len(name_to_index_training_input) - self.num_inference_input_vars = len(name_to_index_inference_input) - self.num_remapped_training_input_vars = len(name_to_index_training_remapped_input) - self.num_remapped_inference_input_vars = len(name_to_index_inference_remapped_input) - self.num_remapped_training_output_vars = len(name_to_index_training_remapped_output) - self.num_remapped_inference_output_vars = len(name_to_index_inference_remapped_output) - self.num_training_output_vars = len(name_to_index_training_output) - self.num_inference_output_vars = len(name_to_index_inference_output) - self.indices_keep_training_input = [] - for key, item in self.data_indices.data.input.name_to_index.items(): - if key in self.data_indices.internal_data.input.name_to_index: - self.indices_keep_training_input.append(item) - self.indices_keep_inference_input = [] - for key, item in self.data_indices.model.input.name_to_index.items(): - if key in self.data_indices.internal_model.input.name_to_index: - self.indices_keep_inference_input.append(item) - self.indices_keep_training_output = [] - for key, item in self.data_indices.data.output.name_to_index.items(): - if key in self.data_indices.internal_data.output.name_to_index: - self.indices_keep_training_output.append(item) - self.indices_keep_inference_output = [] - for key, item in self.data_indices.model.output.name_to_index.items(): - if key in self.data_indices.internal_model.output.name_to_index: - self.indices_keep_inference_output.append(item) - - ( - self.index_training_input, - self.index_training_remapped_input, - self.index_inference_input, - self.index_inference_remapped_input, - self.index_training_output, - self.index_training_backmapped_output, - self.index_inference_output, - self.index_inference_backmapped_output, - self.remappers, - self.backmappers, - ) = ([], [], [], [], [], [], [], [], [], []) - - # Create parameter indices for remapping variables - for name in name_to_index_training_input: - - method = self.methods.get(name, self.default) - - if method == "none": - continue - - if method == "cos_sin": - self.index_training_input.append(name_to_index_training_input[name]) - self.index_training_output.append(name_to_index_training_output[name]) - self.index_inference_input.append(name_to_index_inference_input[name]) - if name in name_to_index_inference_output: - self.index_inference_output.append(name_to_index_inference_output[name]) - else: - # this is a forcing variable. It is not in the inference output. 
- self.index_inference_output.append(None) - multiple_training_output, multiple_inference_output = [], [] - multiple_training_input, multiple_inference_input = [], [] - for name_dst in self.method_config[method][name]: - assert name_dst in self.data_indices.internal_data.input.name_to_index, ( - f"Trying to remap {name} to {name_dst}, but {name_dst} not a variable. " - f"Remap {name} to {name_dst} in config.data.remapped. " - ) - multiple_training_input.append(name_to_index_training_remapped_input[name_dst]) - multiple_training_output.append(name_to_index_training_remapped_output[name_dst]) - multiple_inference_input.append(name_to_index_inference_remapped_input[name_dst]) - if name_dst in name_to_index_inference_remapped_output: - multiple_inference_output.append(name_to_index_inference_remapped_output[name_dst]) - else: - # this is a forcing variable. It is not in the inference output. - multiple_inference_output.append(None) - - self.index_training_remapped_input.append(multiple_training_input) - self.index_inference_remapped_input.append(multiple_inference_input) - self.index_training_backmapped_output.append(multiple_training_output) - self.index_inference_backmapped_output.append(multiple_inference_output) - - self.remappers.append([cos_converter, sin_converter]) - self.backmappers.append(atan2_converter) - - LOGGER.info(f"Map {name} to cosine and sine and save result in {self.method_config[method][name]}.") - - else: - raise ValueError[f"Unknown remapping method for {name}: {method}"] - - def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Remap and convert the input tensor. - - ``` - x : torch.Tensor - Input tensor - in_place : bool - Whether to process the tensor in place. - in_place is not possible for this preprocessor. - ``` - """ - # Choose correct index based on number of variables - if x.shape[-1] == self.num_training_input_vars: - index = self.index_training_input - indices_remapped = self.index_training_remapped_input - indices_keep = self.indices_keep_training_input - target_number_columns = self.num_remapped_training_input_vars - - elif x.shape[-1] == self.num_inference_input_vars: - index = self.index_inference_input - indices_remapped = self.index_inference_remapped_input - indices_keep = self.indices_keep_inference_input - target_number_columns = self.num_remapped_inference_input_vars - + _, _, method_config = cls._process_config(config) + monomappings = Monomapper.supported_methods + multimappings = Multimapper.supported_methods + if all(method in monomappings for method in method_config): + return Monomapper(config, data_indices, statistics) + elif all(method in multimappings for method in method_config): + return Multimapper(config, data_indices, statistics) + elif not ( + any(method in monomappings for method in method_config) + or any(method in multimappings for method in method_config) + ): + raise ValueError("No valid remapping method found.") else: - raise ValueError( - f"Input tensor ({x.shape[-1]}) does not match the training " - f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})", - ) - - # create new tensor with target number of columns - x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) - if in_place and not self.printed_preprocessor_warning: - LOGGER.warning( - "Remapper (preprocessor) called with in_place=True. 
This preprocessor cannot be applied in_place as new columns are added to the tensors.", + raise NotImplementedError( + f"Not implemented: method_config contains a mix of monomapper and multimapper methods: {list(method_config.keys())}" ) - self.printed_preprocessor_warning = True - - # copy variables that are not remapped - x_remapped[..., : len(indices_keep)] = x[..., indices_keep] - - # Remap variables - for idx_dst, remapper, idx_src in zip(indices_remapped, self.remappers, index): - if idx_src is not None: - for jj, ii in enumerate(idx_dst): - x_remapped[..., ii] = remapper[jj](x[..., idx_src]) - - return x_remapped - - def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor: - """Convert and remap the output tensor. - - ``` - x : torch.Tensor - Input tensor - in_place : bool - Whether to process the tensor in place. - in_place is not possible for this postprocessor. - ``` - """ - # Choose correct index based on number of variables - if x.shape[-1] == self.num_remapped_training_output_vars: - index = self.index_training_output - indices_remapped = self.index_training_backmapped_output - indices_keep = self.indices_keep_training_output - target_number_columns = self.num_training_output_vars - - elif x.shape[-1] == self.num_remapped_inference_output_vars: - index = self.index_inference_output - indices_remapped = self.index_inference_backmapped_output - indices_keep = self.indices_keep_inference_output - target_number_columns = self.num_inference_output_vars - - else: - raise ValueError( - f"Input tensor ({x.shape[-1]}) does not match the training " - f"({self.num_remapped_training_output_vars}) or inference shape ({self.num_remapped_inference_output_vars})", - ) - - # create new tensor with target number of columns - x_remapped = torch.zeros(x.shape[:-1] + (target_number_columns,), dtype=x.dtype, device=x.device) - if in_place and not self.printed_postprocessor_warning: - LOGGER.warning( - "Remapper (preprocessor) called with in_place=True. This preprocessor cannot be applied in_place as new columns are added to the tensors.", - ) - self.printed_postprocessor_warning = True - - # copy variables that are not remapped - x_remapped[..., indices_keep] = x[..., : len(indices_keep)] - - # Backmap variables - for idx_dst, backmapper, idx_src in zip(index, self.backmappers, indices_remapped): - if idx_dst is not None: - x_remapped[..., idx_dst] = backmapper(x[..., idx_src]) - - return x_remapped - - -class Remapper(BaseRemapperVariable): - """Remap and convert variables. - - cos_sin: - Remap the variable to cosine and sine. - Map output back to degrees. - - ``` - cos_sin: - "mwd" : ["cos_mwd", "sin_mwd"] - ``` - """ - - def __init__( - self, - config=None, - data_indices: Optional[IndexCollection] = None, - statistics: Optional[dict] = None, - ) -> None: - super().__init__(config, data_indices, statistics) - - self.printed_preprocessor_warning, self.printed_postprocessor_warning = False, False - - self._create_remapping_indices(statistics) - - self._validate_indices() diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py index 58674bd5..66456d6c 100644 --- a/tests/layers/test_graph.py +++ b/tests/layers/test_graph.py @@ -8,10 +8,14 @@ # nor does it submit to any jurisdiction. 
+import einops +import numpy as np import pytest import torch from torch import nn +from torch_geometric.data import HeteroData +from anemoi.models.layers.graph import NamedNodesAttributes from anemoi.models.layers.graph import TrainableTensor @@ -62,3 +66,78 @@ def test_forward_no_trainable(self, init, x): batch_size = 5 output = trainable_tensor(x, batch_size) assert output.shape == (batch_size * x.shape[0], tensor_size + trainable_size) + + +class TestNamedNodesAttributes: + """Test suite for the NamedNodesAttributes class. + + This class contains test cases to verify the functionality of the NamedNodesAttributes class, + including initialization, attribute registration, and forward pass operations. + """ + + nodes_names: list[str] = ["nodes1", "nodes2"] + ndim: int = 2 + num_trainable_params: int = 8 + + @pytest.fixture + def graph_data(self): + graph = HeteroData() + for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names): + graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i + 1)) + return graph + + @staticmethod + def get_n_random_coords(n: int) -> torch.Tensor: + coords = torch.rand(n, TestNamedNodesAttributes.ndim) + coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2) + coords[:, 1] = 2 * np.pi * coords[:, 1] + return coords + + @pytest.fixture + def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes: + return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data) + + def test_init(self, nodes_attributes): + assert isinstance(nodes_attributes, NamedNodesAttributes) + + for nodes_name in self.nodes_names: + assert isinstance(nodes_attributes.num_nodes[nodes_name], int) + assert ( + nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim + == TestNamedNodesAttributes.num_trainable_params + ) + assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor) + + def test_forward(self, nodes_attributes, graph_data): + batch_size = 3 + for nodes_name in self.nodes_names: + output = nodes_attributes(nodes_name, batch_size) + + expected_shape = ( + batch_size * graph_data[nodes_name].num_nodes, + 2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params, + ) + assert output.shape == expected_shape + + # Check if the first part of the output matches the sin-cos transformed coordinates + latlons = getattr(nodes_attributes, f"latlons_{nodes_name}") + repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size) + assert torch.allclose(output[:, : 2 * TestNamedNodesAttributes.ndim], repeated_latlons) + + # Check if the last part of the output is trainable (requires grad) + assert output[:, 2 * TestNamedNodesAttributes.ndim :].requires_grad + + def test_forward_no_trainable(self, graph_data): + no_trainable_attributes = NamedNodesAttributes(0, graph_data) + batch_size = 2 + + for nodes_name in self.nodes_names: + output = no_trainable_attributes(nodes_name, batch_size) + + expected_shape = batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim + assert output.shape == expected_shape + + # Check if the output exactly matches the sin-cos transformed coordinates + latlons = getattr(no_trainable_attributes, f"latlons_{nodes_name}") + repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size) + assert torch.allclose(output, repeated_latlons) diff --git a/tests/preprocessing/test_preprocessor_imputer.py b/tests/preprocessing/test_preprocessor_imputer.py index d0df676a..8557b64b 100644 --- 
a/tests/preprocessing/test_preprocessor_imputer.py +++ b/tests/preprocessing/test_preprocessor_imputer.py @@ -269,6 +269,26 @@ def test_mask_saving(imputer_fixture, data_fixture, request): ("imputer_fixture", "data_fixture"), fixture_combinations, ) +def test_loss_nan_mask(imputer_fixture, data_fixture, request): + """Check that the imputer correctly transforms a tensor with NaNs.""" + x, _ = request.getfixturevalue(data_fixture) + expected = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]]) # only prognostic and diagnostic variables + imputer = request.getfixturevalue(imputer_fixture) + imputer.transform(x) + assert torch.allclose( + imputer.loss_mask_training, expected + ), "Transform does not calculate NaN-mask for loss function scaling correctly." + + +@pytest.mark.parametrize( + ("imputer_fixture", "data_fixture"), + [ + ("default_constant_imputer", "default_constant_data"), + ("non_default_constant_imputer", "non_default_constant_data"), + ("default_input_imputer", "default_input_data"), + ("non_default_input_imputer", "non_default_input_data"), + ], +) def test_reuse_imputer(imputer_fixture, data_fixture, request): """Check that the imputer reuses the mask correctly on subsequent runs.""" x, expected = request.getfixturevalue(data_fixture) diff --git a/tests/preprocessing/test_preprocessor_remapper.py b/tests/preprocessing/test_preprocessor_remapper.py index a0ece2a3..6d27f906 100644 --- a/tests/preprocessing/test_preprocessor_remapper.py +++ b/tests/preprocessing/test_preprocessor_remapper.py @@ -8,11 +8,13 @@ # nor does it submit to any jurisdiction. +import numpy as np import pytest import torch from omegaconf import DictConfig from anemoi.models.data_indices.collection import IndexCollection +from anemoi.models.preprocessing.imputer import InputImputer from anemoi.models.preprocessing.remapper import Remapper @@ -41,22 +43,82 @@ def input_remapper(): return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) +@pytest.fixture() +def input_remapper_1d(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": { + "log1p": "d", + "sqrt": "q", + }, + "forcing": ["z", "q"], + "diagnostic": ["other"], + }, + }, + ) + statistics = {} + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return Remapper(config=config.data.remapper, data_indices=data_indices, statistics=statistics) + + +@pytest.fixture() +def input_imputer(): + config = DictConfig( + { + "diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": { + "cos_sin": { + "d": ["cos_d", "sin_d"], + } + }, + "imputer": {"default": "none", "mean": ["y", "d"]}, + "forcing": ["z", "q"], + "diagnostic": ["other"], + "remapped": { + "d": ["cos_d", "sin_d"], + }, + }, + }, + ) + statistics = { + "mean": np.array([1.0, 2.0, 3.0, 4.5, 3.0, 1.0]), + } + name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "d": 4, "other": 5} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + return InputImputer(config=config.data.imputer, data_indices=data_indices, statistics=statistics) + + def test_remap_not_inplace(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) input_remapper(x, in_place=False) - assert torch.allclose(x, torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]])) + assert torch.allclose( + x, + torch.Tensor([[1.0, 2.0, 
3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]), + ) def test_remap(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) expected_output = torch.Tensor( - [[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]] + [ + [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], + [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795], + ] ) assert torch.allclose(input_remapper.transform(x), expected_output) def test_inverse_transform(input_remapper) -> None: - x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795]]) + x = torch.Tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, -0.8660254, 0.5], + [6.0, 7.0, 8.0, 9.0, 10.0, -0.93358043, -0.35836795], + ] + ) expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose(input_remapper.inverse_transform(x), expected_output) @@ -64,5 +126,77 @@ def test_inverse_transform(input_remapper) -> None: def test_remap_inverse_transform(input_remapper) -> None: x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) assert torch.allclose( - input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), x + input_remapper.inverse_transform(input_remapper.transform(x, in_place=False), in_place=False), + x, + ) + + +def test_transform_loss_mask(input_imputer, input_remapper) -> None: + x = torch.Tensor([[1.0, np.nan, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, np.nan, 10.0]]) + expected_output = torch.Tensor([[1.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0, 0.0]]) + input_imputer.transform(x) + input_remapper.transform(x) + loss_mask_training = input_imputer.loss_mask_training + loss_mask_training = input_remapper.transform_loss_mask(loss_mask_training) + assert torch.allclose(loss_mask_training, expected_output) + + +def test_monomap_transform(input_remapper_1d) -> None: + x = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + expected_output = torch.Tensor( + [ + [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0], + [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0], + ] + ) + assert torch.allclose(input_remapper_1d.transform(x, in_place=False), expected_output) + # inference mode (without prognostic variables) + assert torch.allclose( + input_remapper_1d.transform( + x[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]], in_place=False + ), + expected_output[..., input_remapper_1d.data_indices.data.todict()["input"]["full"]], ) + # this one actually changes the values in x so need to be last + assert torch.allclose(input_remapper_1d.transform(x), expected_output) + + +def test_monomap_inverse_transform(input_remapper_1d) -> None: + expected_output = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 150.0, 5.0], [6.0, 7.0, 8.0, 9.0, 201.0, 10.0]]) + y = torch.Tensor( + [ + [1.0, 2.0, 3.0, np.sqrt(4.0), np.log1p(150.0), 5.0], + [6.0, 7.0, 8.0, np.sqrt(9.0), np.log1p(201.0), 10.0], + ] + ) + assert torch.allclose(input_remapper_1d.inverse_transform(y, in_place=False), expected_output) + # inference mode (without prognostic variables) + assert torch.allclose( + input_remapper_1d.inverse_transform( + y[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]], in_place=False + ), + expected_output[..., input_remapper_1d.data_indices.data.todict()["output"]["full"]], + ) + + +def test_unsupported_remapper(): + config = DictConfig( + { + 
"diagnostics": {"log": {"code": {"level": "DEBUG"}}}, + "data": { + "remapper": {"log1p": "q", "cos_sin": "d"}, + "forcing": [], + "diagnostic": [], + }, + } + ) + statistics = {} + name_to_index = {"x": 0, "y": 1, "q": 2, "d": 3} + data_indices = IndexCollection(config=config, name_to_index=name_to_index) + + with pytest.raises(NotImplementedError): + Remapper( + config=config.data.remapper, + data_indices=data_indices, + statistics=statistics, + )