From 941814b006256e668eca6e9bbcd1e64816572ecf Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 29 May 2024 10:34:36 +0100 Subject: [PATCH 1/3] Add pyproject formater --- .pre-commit-config.yaml | 6 ++- pyproject.toml | 95 +++++++++++++++++++++-------------------- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c928f08..df1f6f2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,6 @@ repos: - --exit-non-zero-on-fix - --preview - - repo: https://github.com/sphinx-contrib/sphinx-lint rev: v0.9.1 hooks: @@ -69,3 +68,8 @@ repos: hooks: - id: docconvert args: ["numpy"] + +- repo: https://github.com/tox-dev/pyproject-fmt + rev: "2.0.4" + hooks: + - id: pyproject-fmt diff --git a/pyproject.toml b/pyproject.toml index ceb5a1d..aa3add1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,69 +10,72 @@ # https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ [build-system] -requires = ["setuptools>=60", "setuptools-scm>=8.0"] +requires = [ + "setuptools>=60", + "setuptools-scm>=8", +] [project] -description = "A package to hold various functions to support training of ML models." name = "anemoi-inference" -dynamic = ["version"] -license = { file = "LICENSE" } -requires-python = ">=3.9" +description = "A package to hold various functions to support training of ML models." +keywords = [ + "ai", + "inference", + "tools", +] +license = { file = "LICENSE" } authors = [ - { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" }, + { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" }, ] -keywords = ["tools", "inference", "ai"] +requires-python = ">=3.9" classifiers = [ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Operating System :: OS Independent", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ] +dynamic = [ + "version", +] dependencies = [ - "tomli", # Only needed before 3.11 - "anemoi-utils>=0.2.1", - "semantic-version", - "pyyaml", - "torch", - "numpy", + "anemoi-utils>=0.2.1", + "numpy", + "pyyaml", + "semantic-version", + "tomli", # Only needed before 3.11 + "torch", ] -[project.optional-dependencies] - - -docs = [ - # For building the documentation - "sphinx", - "sphinx_rtd_theme", - "nbsphinx", - "pandoc", - "sphinx_argparse", +optional-dependencies.all = [ ] - -all = [] - -dev = [] - -[project.urls] -Homepage = "https://github.com/ecmwf/anemoi-inference/" -Documentation = "https://anemoi-inference.readthedocs.io/" -Repository = "https://github.com/ecmwf/anemoi-inference/" -Issues = "https://github.com/ecmwf/anemoi-inference/issues" +optional-dependencies.dev = [ +] +optional-dependencies.docs = [ + "nbsphinx", + "pandoc", + # For building the documentation + "sphinx", + "sphinx-argparse", + "sphinx-rtd-theme", +] +urls.Documentation = "https://anemoi-inference.readthedocs.io/" +urls.Homepage = "https://github.com/ecmwf/anemoi-inference/" +urls.Issues = "https://github.com/ecmwf/anemoi-inference/issues" # Changelog = "https://github.com/ecmwf/anemoi-inference/CHANGELOG.md" - -[project.scripts] -anemoi-inference = "anemoi.inference.__main__:main" +urls.Repository = "https://github.com/ecmwf/anemoi-inference/" +scripts.anemoi-inference = "anemoi.inference.__main__:main" [tool.setuptools_scm] version_file = "src/anemoi/inference/_version.py" From a4a0ed89ea37869024b4c2d1c325023206456e2f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 29 May 2024 14:56:01 +0000 Subject: [PATCH 2/3] Add support for older checkpoints --- src/anemoi/inference/checkpoint/__init__.py | 8 +- .../inference/checkpoint/metadata/__init__.py | 5 + .../inference/checkpoint/metadata/patch.py | 14 + .../checkpoint/metadata/version_0_0_0.py | 242 ++++++++++++++++++ .../checkpoint/metadata/version_0_1_0.py | 3 + .../checkpoint/metadata/version_0_2_0.py | 19 +- src/anemoi/inference/commands/checkpoint.py | 2 + 7 files changed, 282 insertions(+), 11 deletions(-) create mode 100644 src/anemoi/inference/checkpoint/metadata/version_0_0_0.py diff --git a/src/anemoi/inference/checkpoint/__init__.py b/src/anemoi/inference/checkpoint/__init__.py index 2dc8426..a5da931 100644 --- a/src/anemoi/inference/checkpoint/__init__.py +++ b/src/anemoi/inference/checkpoint/__init__.py @@ -12,6 +12,7 @@ import zipfile from functools import cached_property +from anemoi.utils.checkpoints import has_metadata from anemoi.utils.checkpoints import load_metadata from .metadata import Metadata @@ -30,7 +31,12 @@ def __repr__(self): def __getattr__(self, name): if self._metadata is None: - self._metadata = Metadata.from_metadata(load_metadata(self.path)) + try: + self._metadata = Metadata.from_metadata(load_metadata(self.path)) + except ValueError: + if has_metadata(self.path): + raise + self._metadata = Metadata.from_metadata(None) return getattr(self._metadata, name) diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index 07bf70d..1b2aef9 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -16,10 +16,12 @@ def from_versions(checkpoint_version, dataset_version): + from .version_0_0_0 import Version_0_0_0 from .version_0_1_0 import Version_0_1_0 from .version_0_2_0 import Version_0_2_0 VERSIONS = { + ("0.0.0", "0.0.0"): Version_0_0_0, ("1.0.0", "0.1.0"): Version_0_1_0, ("1.0.0", "0.2.0"): Version_0_2_0, } @@ -56,6 +58,9 @@ def to_dict(self): @classmethod def from_metadata(cls, metadata): + if metadata is None: + metadata = dict(version="0.0.0", dataset=dict(version="0.0.0")) + if isinstance(metadata["dataset"], list): from .patch import list_to_dict diff --git a/src/anemoi/inference/checkpoint/metadata/patch.py b/src/anemoi/inference/checkpoint/metadata/patch.py index 74b64b5..6bf5e35 100644 --- a/src/anemoi/inference/checkpoint/metadata/patch.py +++ b/src/anemoi/inference/checkpoint/metadata/patch.py @@ -16,6 +16,11 @@ def drop_fill(metadata): return metadata +def select_fill(metadata): + metadata["variables"] = [x for x in metadata["forward"]["variables"] if x in metadata["select"]] + return metadata + + def rename_fill(metadata, select): rename = metadata["rename"] @@ -49,6 +54,15 @@ def patch(a, b): } ) + if "select" in a: + return select_fill( + { + "action": "select", + "select": a["select"], + "forward": zarr_fill({"action": "zarr", "attrs": b}), + } + ) + if "rename" in a: return rename_fill( { diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py new file mode 100644 index 0000000..afa8b03 --- /dev/null +++ b/src/anemoi/inference/checkpoint/metadata/version_0_0_0.py @@ -0,0 +1,242 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +from . import Metadata + +LOG = logging.getLogger(__name__) + + +class Version_0_0_0(Metadata): + """ + Reference for very old checkpoints + Will not work and need to be updated + """ + + def __init__(self, metadata): + super().__init__(metadata) + + def dump(self, indent=0): + print("Version_0_0_0: Not implemented") + + # Input + area = [90, 0, -90, 360] + grid = [0.25, 0.25] + param_sfc = [ + "z", + "sp", + "msl", + "lsm", + "sst", + "sdor", + "slor", + "10u", + "10v", + "2t", + "2d", + ] + param_level_pl = ( + ["q", "t", "u", "v", "w", "z"], + [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000], + ) + + ordering = [ + "q_50", + "q_100", + "q_150", + "q_200", + "q_250", + "q_300", + "q_400", + "q_500", + "q_600", + "q_700", + "q_850", + "q_925", + "q_1000", + "t_50", + "t_100", + "t_150", + "t_200", + "t_250", + "t_300", + "t_400", + "t_500", + "t_600", + "t_700", + "t_850", + "t_925", + "t_1000", + "u_50", + "u_100", + "u_150", + "u_200", + "u_250", + "u_300", + "u_400", + "u_500", + "u_600", + "u_700", + "u_850", + "u_925", + "u_1000", + "v_50", + "v_100", + "v_150", + "v_200", + "v_250", + "v_300", + "v_400", + "v_500", + "v_600", + "v_700", + "v_850", + "v_925", + "v_1000", + "w_50", + "w_100", + "w_150", + "w_200", + "w_250", + "w_300", + "w_400", + "w_500", + "w_600", + "w_700", + "w_850", + "w_925", + "w_1000", + "z_50", + "z_100", + "z_150", + "z_200", + "z_250", + "z_300", + "z_400", + "z_500", + "z_600", + "z_700", + "z_850", + "z_925", + "z_1000", + "sp", + "msl", + "sst", + "10u", + "10v", + "2t", + "2d", + "z", + "lsm", + "sdor", + "slor", + ] + + param_format = {"param_level": "{param}{levelist}"} + + computed_constants = [ + "cos_latitude", + "cos_longitude", + "sin_latitude", + "sin_longitude", + ] + computed_constants_mask = [] + + computer_forcing = [ + "cos_julian_day", + "cos_local_time", + "sin_julian_day", + "sin_local_time", + "insolation", + ] + + @property + def variables(self): + return self.ordering + self.computed_constants + self.forcing_params + + @property + def num_input_features(self): + raise NotImplementedError() + + @property + def data_to_model(self): + raise NotImplementedError() + + @property + def model_to_data(self): + raise NotImplementedError() + + ########################################################################### + @property + def order_by(self): + return dict( + valid_datetime="ascending", + param_level=self.ordering, + remapping={"param_level": "{param}_{levelist}"}, + ) + + @property + def select(self): + return dict( + param_level=self.variables, + remapping={"param_level": "{param}_{levelist}"}, + ) + + ########################################################################### + + @property + def constants_from_input(self): + raise NotImplementedError() + + @property + def constants_from_input_mask(self): + raise NotImplementedError() + + @property + def constant_data_from_input_mask(self): + raise NotImplementedError() + + ########################################################################### + + @property + def prognostic_input_mask(self): + raise NotImplementedError() + + @property + def prognostic_data_input_mask(self): + raise NotImplementedError() + + @property + def prognostic_output_mask(self): + raise NotImplementedError() + + @property + def diagnostic_output_mask(self): + raise NotImplementedError() + + @property + def diagnostic_params(self): + raise NotImplementedError() + + @property + def prognostic_params(self): + raise NotImplementedError() + + ########################################################################### + @property + def precision(self): + raise NotImplementedError() + + @property + def multi_step(self): + raise NotImplementedError() + + @property + def imputable_variables(self): + raise NotImplementedError() diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py index 72fceaa..964624c 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_1_0.py @@ -84,3 +84,6 @@ def patch_metadata(self): pl.remove([param, level]) if [param, level] in ml: ml.remove([param, level]) + + def dump(self, indent=0): + print("Version_0_1_0: Not implemented") diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index 470928a..4da04d6 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -84,6 +84,7 @@ def param_level_ml(self): class ZarrRequest(DataRequest): def __init__(self, metadata): + super().__init__(metadata) self.attributes = metadata["attrs"] self.request = self.attributes["data_request"] @@ -143,10 +144,13 @@ class StatisticsRequest(Forward): class RenameRequest(Forward): + # Drop variables + # No need to rename anything as self.metadata["variables"] is already up to date + @property def variables_with_nans(self): - raise NotImplementedError() - return sorted(self.forward.variables_with_nans) + rename = self.metadata["rename"] + return sorted([rename.get(x, x) for x in self.forward.variables_with_nans]) class MultiRequest(Forward): @@ -254,17 +258,12 @@ def variables_with_nans(self): class DropRequest(SelectRequest): - @property - def variables(self): - raise NotImplementedError() + # Drop variables + # No need to drop anything as self.metadata["variables"] is already up to date @property def variables_with_nans(self): - result = set() - for dataset in self.metadata["datasets"]: - result.extend(dataset.variables_with_nans) - - return sorted(result) + return [x for x in self.forward.variables_with_nans if x in self.variables] def data_request(specific): diff --git a/src/anemoi/inference/commands/checkpoint.py b/src/anemoi/inference/commands/checkpoint.py index 423a074..2dbdb5b 100644 --- a/src/anemoi/inference/commands/checkpoint.py +++ b/src/anemoi/inference/commands/checkpoint.py @@ -29,6 +29,8 @@ def run(self, args): c.dump() return + c.dump() + print("area:", c.area) print("computed_constants_mask:", c.computed_constants_mask) print("computed_constants:", c.computed_constants) From 8ac1cfa3a1bf5da09fa1001dd39ac34c1ae9522c Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Wed, 29 May 2024 16:13:18 +0100 Subject: [PATCH 3/3] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8512844..23bbb13 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# anemoi-utils +# anemoi-inference **DISCLAIMER** This project is **BETA** and will be **Experimental** for the foreseeable future.