From 24cafa1f539dc7b072998f9215bbcb8490b466cd Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 6 Jan 2025 13:54:17 +0000 Subject: [PATCH] CDS Input Support --- docs/configs/inputs.rst | 36 +++++++++ docs/configs/inputs_10.yaml | 7 ++ docs/configs/inputs_8.yaml | 3 + docs/configs/inputs_9.yaml | 24 ++++++ src/anemoi/inference/inputs/cds.py | 118 +++++++++++++++++++++++++++++ 5 files changed, 188 insertions(+) create mode 100644 docs/configs/inputs_10.yaml create mode 100644 docs/configs/inputs_8.yaml create mode 100644 docs/configs/inputs_9.yaml create mode 100644 src/anemoi/inference/inputs/cds.py diff --git a/docs/configs/inputs.rst b/docs/configs/inputs.rst index 8343b10..058ed5e 100644 --- a/docs/configs/inputs.rst +++ b/docs/configs/inputs.rst @@ -95,3 +95,39 @@ You can change that to use ERA5 reanalysis data (``class=ea``). The ``mars`` input also accepts the ``namer`` parameter of the GRIB input. + +***** + cds +***** + +You can also specify the input as ``cds`` to read the data from the +`Climate Data Store <https://cds.climate.copernicus.eu/>`_. This requires +the `cdsapi` package to be installed, and the user to have a CDS +account. + +.. literalinclude:: inputs_8.yaml + :language: yaml + +As the CDS contains a plethora of +`datasets <https://cds.climate.copernicus.eu/datasets>`_, you can specify +the dataset you want to use with the key `dataset`. + +This can be a str in which case the dataset is used for all requests, or +a dict of any number of levels which will be descended based on the +key/values for each request. + +You can use `*` to match any value not explicitly listed for a key, e.g. set a +dataset for `param: 2t` and `param: *` to represent any other param. + +.. literalinclude:: inputs_9.yaml + :language: yaml + +In the above example, the dataset `reanalysis-era5-pressure-levels` is +used for all requests with `levtype: pl` and `reanalysis-era5-single-levels` is used +for all requests with `levtype: sfc`. 
+ +Additionally, any kwarg can be passed to be added to all requests, i.e. +for ERA5 data, `product_type: 'reanalysis'` is needed. + +.. literalinclude:: inputs_10.yaml + :language: yaml diff --git a/docs/configs/inputs_10.yaml b/docs/configs/inputs_10.yaml new file mode 100644 index 0000000..f306b05 --- /dev/null +++ b/docs/configs/inputs_10.yaml @@ -0,0 +1,7 @@ +input: + cds: + dataset: + levtype: + pl: reanalysis-era5-pressure-levels + sfc: reanalysis-era5-single-levels + product_type: 'reanalysis' diff --git a/docs/configs/inputs_8.yaml b/docs/configs/inputs_8.yaml new file mode 100644 index 0000000..00074fc --- /dev/null +++ b/docs/configs/inputs_8.yaml @@ -0,0 +1,3 @@ +input: + cds: + dataset: ??? diff --git a/docs/configs/inputs_9.yaml b/docs/configs/inputs_9.yaml new file mode 100644 index 0000000..6a4d36c --- /dev/null +++ b/docs/configs/inputs_9.yaml @@ -0,0 +1,24 @@ +input: + cds: + # Dataset examples + ## As a string + dataset: + 'reanalysis-era5-pressure-levels' + + ## As a simple dictionary + dataset: + levtype: + pl: reanalysis-era5-pressure-levels + sfc: reanalysis-era5-single-levels + + ## As a complex dictionary + dataset: + stream: + oper: + levtype: + pl: reanalysis-era5-pressure-levels + sfc: reanalysis-era5-single-levels + an: + # ... Other datasets + '*': # Any other stream + # ... Other datasets diff --git a/src/anemoi/inference/inputs/cds.py b/src/anemoi/inference/inputs/cds.py new file mode 100644 index 0000000..3f663cf --- /dev/null +++ b/src/anemoi/inference/inputs/cds.py @@ -0,0 +1,118 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
import logging

from earthkit.data.utils.dates import to_datetime

from . import input_registry
from .grib import GribInput
from .mars import postproc

LOG = logging.getLogger(__name__)


def _format_request(request):
    """Render a request as a ``key=value,...`` string for debug logging.

    List and tuple values are joined with ``/``; every other value is
    converted with ``str``.
    """
    parts = []
    for key, value in request.items():
        if isinstance(value, (list, tuple)):
            text = "/".join(str(x) for x in value)
        else:
            text = str(value)
        parts.append(f"{key}={text}")
    return ",".join(parts)


def _resolve_dataset(dataset, request):
    """Return the CDS dataset name to use for ``request``.

    Parameters
    ----------
    dataset : str or dict
        Either a dataset name used for every request, or a nested dictionary
        of the form ``{request_key: {request_value: <name or nested dict>}}``.
        At each level, ``"*"`` may be used as a fallback for values that are
        not listed explicitly.
    request : dict
        The request whose key/values drive the descent through ``dataset``.
        Values used as selectors must be hashable.

    Returns
    -------
    str
        The resolved dataset name.

    Raises
    ------
    KeyError
        If no key of the current level is present in the request, or the
        request value is not listed and no ``"*"`` fallback exists.
    TypeError
        If ``dataset`` (or an intermediate level) is not of a supported type,
        or the descent does not end on a string.
    """
    if isinstance(dataset, str):
        return dataset

    if not isinstance(dataset, dict):
        raise TypeError(f"`dataset` must be a str or a dict, got {type(dataset).__name__}: {dataset!r}")

    level = dataset
    while isinstance(level, dict):
        # Use the first selector (in configuration order) that the request
        # provides, so that the choice is deterministic when several match.
        key = next((k for k in level if k in request), None)
        if key is None:
            raise KeyError(
                f"While searching for dataset, could not find any valid key in dictionary: {request.keys()}, {level}"
            )

        options = level[key]
        if not isinstance(options, dict):
            raise TypeError(
                f"Dataset dictionary entry for {key!r} must map request values to datasets, got {options!r}."
            )

        value = request[key]
        if value in options:
            level = options[value]
        elif "*" in options:
            level = options["*"]
        else:
            raise KeyError(f"Dataset dictionary does not contain key {value!r} in {key!r}: {dict(options)}.")

    if not isinstance(level, str):
        # E.g. an empty leaf in the YAML configuration, which loads as None.
        raise TypeError(f"Dataset lookup resolved to {level!r}, expected a dataset name (str).")

    return level


def retrieve(requests, grid, area, dataset, **kwargs):
    """Retrieve the fields described by ``requests`` from the CDS.

    Parameters
    ----------
    requests : iterable of dict
        Requests to issue. Each request is updated in place with the
        post-processing keys returned by ``postproc(grid, area)`` and with
        ``kwargs`` before being sent.
    grid, area
        Forwarded to ``postproc`` to build the post-processing keys.
    dataset : str or dict
        Dataset name, or nested dictionary used to select a dataset per
        request (see ``_resolve_dataset``).
    **kwargs
        Extra key/values added to every request.

    Returns
    -------
    The concatenation of the earthkit-data sources obtained for each request,
    starting from an ``"empty"`` source.
    """
    import earthkit.data as ekd

    pproc = postproc(grid, area)

    result = ekd.from_source("empty")
    for r in requests:
        # Resolve the dataset from the request as received, before the
        # post-processing keys and extra kwargs are merged into it.
        name = _resolve_dataset(dataset, r)

        r.update(pproc)
        r.update(kwargs)

        # Only build the (potentially long) request string when it is needed.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("%s", _format_request(r))

        result += ekd.from_source("cds", name, r)

    return result


@input_registry.register("cds")
class CDSInput(GribInput):
    """Get input fields from CDS.

    Parameters
    ----------
    context
        Forwarded to ``GribInput``.
    dataset : str or dict
        Dataset name, or nested dictionary selecting a dataset per request.
    namer : optional
        Forwarded to ``GribInput``.
    **kwargs
        Extra key/values added to every CDS request.
    """

    def __init__(self, context, *, dataset, namer=None, **kwargs):
        super().__init__(context, namer=namer)

        self.variables = self.checkpoint.variables_from_input(include_forcings=False)
        self.dataset = dataset
        self.kwargs = kwargs

    def create_input_state(self, *, date):
        """Build the input state for ``date`` from fields retrieved from the CDS.

        When ``date`` is ``None``, ``to_datetime(-1)`` is used and a warning
        is logged.
        """
        if date is None:
            date = to_datetime(-1)
            LOG.warning("CDSInput: `date` parameter not provided, using yesterday's date: %s", date)

        date = to_datetime(date)

        # One date per entry of `checkpoint.lagged`
        # (presumably time offsets relative to `date` -- confirm against the checkpoint).
        dates = [date + h for h in self.checkpoint.lagged]

        return self._create_input_state(
            self.retrieve(self.variables, dates),
            variables=self.variables,
            date=date,
        )

    def retrieve(self, variables, dates):
        """Retrieve ``variables`` at ``dates`` from the CDS.

        Raises
        ------
        ValueError
            If the checkpoint produces no request for the given variables and dates.
        """
        requests = self.checkpoint.mars_requests(
            variables=variables,
            dates=dates,
            use_grib_paramid=self.context.use_grib_paramid,
        )

        if not requests:
            raise ValueError(f"No requests for {variables} ({dates})")

        return retrieve(
            requests,
            self.checkpoint.grid,
            self.checkpoint.area,
            dataset=self.dataset,
            expver="0001",
            **self.kwargs,
        )

    def template(self, variable, date, **kwargs):
        """Return the first field retrieved for ``variable`` at ``date``."""
        return self.retrieve([variable], [date])[0]

    def load_forcings(self, variables, dates):
        """Load the forcing ``variables`` at ``dates`` from the CDS."""
        return self._load_forcings(self.retrieve(variables, dates), variables, dates)