Skip to content

Commit

Permalink
Instantiate runner from config and add CRPS runner (#106)
Browse files Browse the repository at this point in the history
* Instantiate runner from config

* Add CRPS runner

* header year
  • Loading branch information
gmertes authored Jan 16, 2025
1 parent 700a274 commit 48cc120
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import logging

from ..config import load_config
from ..runners.default import DefaultRunner
from ..runners import create_runner
from . import Command

LOG = logging.getLogger(__name__)
Expand All @@ -35,7 +35,7 @@ def run(self, args):
if config.description is not None:
LOG.info("%s", config.description)

runner = DefaultRunner(config)
runner = create_runner(config)

input = runner.create_input()
output = runner.create_output()
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class Config:
checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str]
"""A path to an Anemoi checkpoint file."""

runner: str = "default"
"""The runner to use."""

date: Union[str, int, datetime.datetime, None] = None
"""The starting date for the forecast. If not provided, the date will depend on the selected Input object. If a string, it is parsed by :func:`anemoi.utils.dates.as_datetime`.
"""
Expand Down
7 changes: 6 additions & 1 deletion src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def model(self):
# model.set_inference_options(**self.inference_options)
return model

def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
# extra args are only used in specific runners
# TODO: move this to a Stepper class.
return model.predict_step(input_tensor_torch)

def forecast(self, lead_time, input_tensor_numpy, input_state):
self.model.eval()

Expand Down Expand Up @@ -283,7 +288,7 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):

# Predict next state of atmosphere
with torch.autocast(device_type=self.device, dtype=self.autocast):
y_pred = self.model.predict_step(input_tensor_torch)
y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s)

# Detach tensor and squeeze (should we detach here?)
output = np.squeeze(y_pred.cpu().numpy()) # shape: (values, variables)
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/inference/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
runner_registry = Registry(__name__)


def create_runner(context, config):
return runner_registry.from_config(config, context)
def create_runner(config):
return runner_registry.create(config.runner, config)
22 changes: 22 additions & 0 deletions src/anemoi/inference/runners/crps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

from . import runner_registry
from .default import DefaultRunner

LOG = logging.getLogger(__name__)


@runner_registry.register("crps")
class CrpsRunner(DefaultRunner):
def predict_step(self, model, input_tensor_torch, fcstep, **kwargs):
return model.predict_step(input_tensor_torch, fcstep=fcstep)

0 comments on commit 48cc120

Please sign in to comment.