Skip to content

Commit

Permalink
feat: Add support for output_frequency to write less output (#109)
Browse files Browse the repository at this point in the history
* feat: Add support for `output_frequency` to write less output

---------

Co-authored-by: Dieter Van den Bleeken <[email protected]>
  • Loading branch information
b8raoult and dietervdb-meteo authored Jan 27, 2025
1 parent 8a30d22 commit 828e2b0
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 88 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Add CONTRIBUTORS.md file (#36)
- Add sanetise command
- Add support for huggingface
- Add support for `output_frequency` to write less output
- Added ability to run inference over multiple GPUs [#55](https://github.com/ecmwf/anemoi-inference/pull/55)

### Changed
Expand Down
3 changes: 1 addition & 2 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def run(self, args):

input_state = input.create_input_state(date=config.date)

if config.write_initial_state:
output.write_initial_state(input_state)
output.write_initial_state(input_state)

for state in runner.run(input_state=input_state, lead_time=config.lead_time):
output.write_state(state)
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class Config:
"""Whether to write the initial state to the output file. If the model is multi-step, only fields at the forecast reference date are
written."""

output_frequency: Optional[str] = None
"""The frequency at which to write the output. This can be a string or an integer. If a string, it is parsed by :func:`anemoi.utils.dates.as_timedelta`."""

env: Dict[str, str | int] = {}
"""Environment variables to set before running the model. This may be useful to control some packages
such as `eccodes`. In certain cases, the variables may be set too late, if the package for which they are intended
Expand Down
11 changes: 11 additions & 0 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class Context(ABC):
verbosity = 0
development_hacks = {} # For testing purposes, don't use in production

# Some runners will set these values, which can be queried by Output objects,
# but may remain as None

reference_date = None
time_step = None
lead_time = None
output_frequency = None
write_initial_state = True

##################################################################

@property
@abstractmethod
def checkpoint(self):
Expand Down
98 changes: 94 additions & 4 deletions src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,67 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
import logging
from abc import ABC
from abc import abstractmethod
from functools import cached_property

LOG = logging.getLogger(__name__)


class Output(ABC):
"""_summary_"""

def __init__(self, context):
def __init__(self, context, output_frequency=None, write_initial_state=None):

self.context = context
self.checkpoint = context.checkpoint
self.reference_date = None

self._write_step_zero = write_initial_state
self._output_frequency = output_frequency

def __repr__(self):
return f"{self.__class__.__name__}()"

@abstractmethod
def step(self, state):
return state["date"] - self.reference_date

def write_initial_state(self, state):
pass
self._init(state)
if self.write_step_zero:
return self.write_initial_step(state)

@abstractmethod
def write_state(self, state):
self._init(state)

step = self.step(state)
if self.output_frequency is not None:
if (step % self.output_frequency).total_seconds() != 0:
return

return self.write_step(state)

def _init(self, state):
if self.reference_date is not None:
return

self.reference_date = state["date"]

self.open(state)

def write_initial_step(self, state):
    """Write the initial state after passing it through `reduce`.

    Not meant to be called directly; go through `write_initial_state`.
    """
    self.write_step(self.reduce(state))

@abstractmethod
def write_step(self, state):
    """Write one state; concrete outputs provide the implementation.

    Not meant to be called directly; go through `write_state`.
    """

def reduce(self, state):
Expand All @@ -36,5 +77,54 @@ def reduce(self, state):
reduced_state["fields"][field] = values[-1, :]
return reduced_state

def open(self, state):
    """Hook called once by `_init` with the first state received.

    The default does nothing; override it when initialisation is needed.
    """
    return None

def close(self):
    """Close the output. The default implementation does nothing."""
    return None

@cached_property
def write_step_zero(self):
    """Whether the initial state should be written.

    The value stored in ``_write_step_zero`` wins when it is not ``None``;
    otherwise the context's ``write_initial_state`` is used.
    """
    explicit = self._write_step_zero
    return self.context.write_initial_state if explicit is None else explicit

@cached_property
def output_frequency(self):
    """Effective output frequency, converted with `as_timedelta`, or ``None``.

    The value stored in ``_output_frequency`` takes precedence; when it is
    ``None`` the context's ``output_frequency`` is used instead. ``None`` is
    returned when neither is set.
    """
    from anemoi.utils.dates import as_timedelta

    raw = self._output_frequency
    if raw is None:
        # Fall back to the context-wide setting.
        raw = self.context.output_frequency

    return None if raw is None else as_timedelta(raw)

def print_summary(self, depth=0):
    """Log this output's effective settings, indented by `depth` spaces."""
    indent = " " * depth
    LOG.info(
        "%s%s: output_frequency=%s write_initial_state=%s",
        indent,
        self,
        self.output_frequency,
        self.write_step_zero,
    )


class ForwardOutput(Output):
    """Output that forwards the states it receives to other outputs.

    Subclass this when implementing an output that wraps another output.
    `output_frequency` should only apply to the leaves of such a chain, so it
    is always ``None`` at this level.
    """

    def __init__(self, context, output_frequency=None, write_initial_state=None):
        """Initialise the forwarding output.

        Parameters
        ----------
        context
            The context; its ``output_frequency`` is ignored by this class.
        output_frequency
            Accepted for signature compatibility but ignored; a warning is
            logged when it is not ``None``.
        write_initial_state
            Passed through unchanged to `Output.__init__`.
        """
        super().__init__(context, output_frequency=None, write_initial_state=write_initial_state)

        # Warn whichever way the frequency was requested: an explicit argument
        # is dropped just like the context-wide setting.
        if output_frequency is not None or self.context.output_frequency is not None:
            LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__)

    @cached_property
    def output_frequency(self):
        """Always ``None``: forwarding outputs never filter by frequency."""
        return None
18 changes: 12 additions & 6 deletions src/anemoi/inference/outputs/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,31 @@

import logging

from ..output import Output
from ..output import ForwardOutput
from . import create_output
from . import output_registry

LOG = logging.getLogger(__name__)


@output_registry.register("apply_mask")
class ApplyMaskOutput(Output):
class ApplyMaskOutput(ForwardOutput):
"""_summary_"""

def __init__(self, context, *, mask, output):
super().__init__(context)
def __init__(self, context, *, mask, output, output_frequency=None, write_initial_state=None):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.mask = self.checkpoint.load_supporting_array(mask)
self.output = create_output(context, output)

def __repr__(self):
return f"ApplyMaskOutput({self.mask}, {self.output})"

def write_initial_state(self, state):
def write_initial_step(self, state):
# Note: we forward to the wrapped output's `write_initial_state`, so its own options are applied again
self.output.write_initial_state(self._apply_mask(state))

def write_state(self, state):
def write_step(self, state):
# Note: we forward to the wrapped output's `write_state`, so its own options are applied again
self.output.write_state(self._apply_mask(state))

def _apply_mask(self, state):
Expand All @@ -52,3 +54,7 @@ def _apply_mask(self, state):

def close(self):
    """Close the wrapped output."""
    wrapped = self.output
    wrapped.close()

def print_summary(self, depth=0):
    """Log this output's summary, then the wrapped output's one level deeper."""
    super().print_summary(depth)
    nested_depth = depth + 1
    self.output.print_summary(nested_depth)
21 changes: 12 additions & 9 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,19 @@

import numpy as np

from ..output import Output
from ..output import ForwardOutput
from . import create_output
from . import output_registry

LOG = logging.getLogger(__name__)


@output_registry.register("extract_lam")
class ExtractLamOutput(Output):
class ExtractLamOutput(ForwardOutput):
"""_summary_"""

def __init__(self, context, *, output, lam="lam_0"):
super().__init__(context)

LOG.info("context.checkpoint.supporting_arrays %s", list(context.checkpoint.supporting_arrays.keys()))
LOG.info("%s", len(context.checkpoint.supporting_arrays["grid_indices"]))
def __init__(self, context, *, output, lam="lam_0", output_frequency=None, write_initial_state=None):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)

if "cutout_mask" in self.checkpoint.supporting_arrays:
# Backwards compatibility
Expand All @@ -49,10 +46,12 @@ def __init__(self, context, *, output, lam="lam_0"):
def __repr__(self):
return f"ExtractLamOutput({self.points}, {self.output})"

def write_initial_state(self, state):
def write_initial_step(self, state):
# Note: we forward to the wrapped output's `write_initial_state`, so its own options are applied again
self.output.write_initial_state(self._apply_mask(state))

def write_state(self, state):
def write_step(self, state):
# Note: we forward to the wrapped output's `write_state`, so its own options are applied again
self.output.write_state(self._apply_mask(state))

def _apply_mask(self, state):
Expand All @@ -74,3 +73,7 @@ def _apply_mask(self, state):

def close(self):
    """Close the wrapped output."""
    target = self.output
    target.close()

def print_summary(self, depth=0):
    """Log this output's summary, then the wrapped output's one level deeper."""
    super().print_summary(depth)
    child_depth = depth + 1
    self.output.print_summary(child_depth)
21 changes: 16 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,19 @@ class GribOutput(Output):
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
def __init__(
self,
context,
*,
encoding=None,
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_state=None,
):
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
self.quiet = set()
Expand All @@ -88,7 +99,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
def write_initial_step(self, state):
# We trust the GribInput class to provide the templates
# matching the input state

Expand All @@ -98,7 +109,7 @@ def write_initial_state(self, state):
if template is None:
# We can currently only write grib output if we have a grib input
raise ValueError(
"GRIB output only works if the input is GRIB (for now). Set `write_initial_state` to `false`."
"GRIB output only works if the input is GRIB (for now). Set `write_initial_step` to `false`."
)

variable = self.typed_variables[name]
Expand Down Expand Up @@ -128,7 +139,7 @@ def write_initial_state(self, state):

self.write_message(values, template=template, **keys)

def write_state(self, state):
def write_step(self, state):

reference_date = self.context.reference_date
date = state["date"]
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def __init__(
grib1_keys=None,
grib2_keys=None,
modifiers=None,
output_frequency=None,
write_initial_state=None,
**kwargs,
):
super().__init__(
Expand All @@ -99,6 +101,8 @@ def __init__(
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
output_frequency=output_frequency,
write_initial_state=write_initial_state,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
Expand Down
Loading

0 comments on commit 828e2b0

Please sign in to comment.