diff --git a/CHANGELOG.md b/CHANGELOG.md index c5e61284..2c792db3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ Keep it human-readable, your future self will thank you! - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) - New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69) +- Added `supporting_arrays` argument, which contains arrays to store in checkpoints. [#97](https://github.com/ecmwf/anemoi-models/pull/97) ### Changed - Bugfixes for CI diff --git a/src/anemoi/models/interface/__init__.py b/src/anemoi/models/interface/__init__.py index 25b7852a..261dec29 100644 --- a/src/anemoi/models/interface/__init__.py +++ b/src/anemoi/models/interface/__init__.py @@ -37,6 +37,8 @@ class AnemoiModelInterface(torch.nn.Module): Statistics for the data. metadata : dict Metadata for the model. + supporting_arrays : dict + Numpy arraysto store in the checkpoint. data_indices : dict Indices for the data. pre_processors : Processors @@ -48,7 +50,14 @@ class AnemoiModelInterface(torch.nn.Module): """ def __init__( - self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict + self, + *, + config: DotDict, + graph_data: HeteroData, + statistics: dict, + data_indices: dict, + metadata: dict, + supporting_arrays: dict = None, ) -> None: super().__init__() self.config = config @@ -57,6 +66,7 @@ def __init__( self.graph_data = graph_data self.statistics = statistics self.metadata = metadata + self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {} self.data_indices = data_indices self._build_model()