From 48cc1200215835f5bf041ef11820686083d73758 Mon Sep 17 00:00:00 2001 From: Gert Mertes <13658335+gmertes@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:09:30 +0000 Subject: [PATCH] Instantiate runner from config and add CRPS runner (#106) * Instantiate runner from config * Add CRPS runner * header year --- src/anemoi/inference/commands/run.py | 4 ++-- src/anemoi/inference/config.py | 3 +++ src/anemoi/inference/runner.py | 7 ++++++- src/anemoi/inference/runners/__init__.py | 4 ++-- src/anemoi/inference/runners/crps.py | 22 ++++++++++++++++++++++ 5 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 src/anemoi/inference/runners/crps.py diff --git a/src/anemoi/inference/commands/run.py b/src/anemoi/inference/commands/run.py index 1ebdd0f8..2922d418 100644 --- a/src/anemoi/inference/commands/run.py +++ b/src/anemoi/inference/commands/run.py @@ -12,7 +12,7 @@ import logging from ..config import load_config -from ..runners.default import DefaultRunner +from ..runners import create_runner from . import Command LOG = logging.getLogger(__name__) @@ -35,7 +35,7 @@ def run(self, args): if config.description is not None: LOG.info("%s", config.description) - runner = DefaultRunner(config) + runner = create_runner(config) input = runner.create_input() output = runner.create_output() diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 8c71924b..199a8c02 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -35,6 +35,9 @@ class Config: checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str] """A path to an Anemoi checkpoint file.""" + runner: str = "default" + """The runner to use.""" + date: Union[str, int, datetime.datetime, None] = None """The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`. """ diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index da609156..4ba0a8af 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -238,6 +238,11 @@ def model(self): # model.set_inference_options(**self.inference_options) return model + def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): + # extra args are only used in specific runners + # TODO: move this to a Stepper class. + return model.predict_step(input_tensor_torch) + def forecast(self, lead_time, input_tensor_numpy, input_state): self.model.eval() @@ -283,7 +288,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state): # Predict next state of atmosphere with torch.autocast(device_type=self.device, dtype=self.autocast): - y_pred = self.model.predict_step(input_tensor_torch) + y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s) # Detach tensor and squeeze (should we detach here?) output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables) diff --git a/src/anemoi/inference/runners/__init__.py b/src/anemoi/inference/runners/__init__.py index 2eb86bda..7cc410b6 100644 --- a/src/anemoi/inference/runners/__init__.py +++ b/src/anemoi/inference/runners/__init__.py @@ -11,5 +11,5 @@ runner_registry = Registry(__name__) -def create_runner(context, config): - return runner_registry.from_config(config, context) +def create_runner(config): + return runner_registry.create(config.runner, config) diff --git a/src/anemoi/inference/runners/crps.py b/src/anemoi/inference/runners/crps.py new file mode 100644 index 00000000..137d1946 --- /dev/null +++ b/src/anemoi/inference/runners/crps.py @@ -0,0 +1,22 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +from . import runner_registry +from .default import DefaultRunner + +LOG = logging.getLogger(__name__) + + +@runner_registry.register("crps") +class CrpsRunner(DefaultRunner): + def predict_step(self, model, input_tensor_torch, fcstep, **kwargs): + return model.predict_step(input_tensor_torch, fcstep=fcstep)