Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better prepml support #89

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ _version.py
*.gif
*.zarr/
/*-plots/
/definitions*/
13 changes: 13 additions & 0 deletions src/anemoi/inference/commands/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def add_arguments(self, command_parser):
command_parser.add_argument("--output", type=str, help="Output file")
command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
command_parser.add_argument("--extra", action="append", help="Additional request values. Can be repeated")
command_parser.add_argument("--retrieve-fields-type", type=str, help="Type of fields to retrieve")
command_parser.add_argument("overrides", nargs="*", help="Overrides.")

def run(self, args):
Expand All @@ -45,6 +46,18 @@ def run(self, args):
area = runner.checkpoint.area
grid = runner.checkpoint.grid

if args.retrieve_fields_type is not None:
selected = set()

for name, kinds in runner.checkpoint.variable_categories().items():
if "computed" in kinds:
continue
for kind in kinds:
if args.retrieve_fields_type.startswith(kind): # PrepML adds an 's' to the type
selected.add(name)

variables = sorted(selected)

extra = postproc(grid, area)

for r in args.extra or []:
Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class Config:
development_hacks: dict = {}
"""A dictionary of development hacks to apply to the runner. This is used to test new features or to work around"""

debugging_info: dict = {}
"""A dictionary to store debug information. This is ignored."""


def load_config(path, overrides, defaults=None, Configuration=Configuration):

Expand Down
5 changes: 2 additions & 3 deletions src/anemoi/inference/grib/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,10 @@ def grib_keys(
result.update(grib2_keys.get(param, {}))

result.setdefault("type", "fc")
type = result.get("type")

if type is not None:
if result.get("type") in ("an", "fc"):
# For organisations that do not use type
result.setdefault("dataType", type)
result.setdefault("dataType", result.pop("type"))

# if stream is not None:
# result.setdefault("stream", stream)
Expand Down
61 changes: 56 additions & 5 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,58 @@
LOG = logging.getLogger(__name__)


class HindcastOutput:

def __init__(self, reference_year):
self.reference_year = reference_year

def __call__(self, values, template, keys):

if "date" not in keys:
assert template.metadata("hdate", default=None) is None, template
date = template.metadata("date")
else:
date = keys.pop("date")

for k in ("date", "hdate"):
keys.pop(k, None)

keys["edition"] = 1
keys["localDefinitionNumber"] = 30
keys["dataDate"] = int(to_datetime(date).strftime("%Y%m%d"))
keys["referenceDate"] = int(to_datetime(date).replace(year=self.reference_year).strftime("%Y%m%d"))

return values, template, keys


MODIFIERS = dict(hindcast=HindcastOutput)


def modifier_factory(modifiers):

if modifiers is None:
return []

if not isinstance(modifiers, list):
modifiers = [modifiers]

result = []
for modifier in modifiers:
assert isinstance(modifier, dict), modifier
assert len(modifier) == 1, modifier

klass = list(modifier.keys())[0]
result.append(MODIFIERS[klass](**modifier[klass]))

return result


class GribOutput(Output):
"""
Handles grib
"""

def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None):
def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, grib2_keys=None, modifiers=None):
super().__init__(context)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
Expand All @@ -40,6 +86,7 @@ def __init__(self, context, *, encoding=None, templates=None, grib1_keys=None, g
self._template_date = None
self._template_reuse = None
self.use_closest_template = False # Off for now
self.modifiers = modifier_factory(modifiers)

def write_initial_state(self, state):
# We trust the GribInput class to provide the templates
Expand Down Expand Up @@ -76,7 +123,8 @@ def write_initial_state(self, state):
quiet=self.quiet,
)

# LOG.info("Step 0 GRIB %s\n%s", template, json.dumps(keys, indent=4))
for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

self.write_message(values, template=template, **keys)

Expand All @@ -95,7 +143,7 @@ def write_state(self, state):
self.quiet.add("_grib_templates_for_output")
LOG.warning("Input is not GRIB.")

for name, value in state["fields"].items():
for name, values in state["fields"].items():
keys = {}

variable = self.typed_variables[name]
Expand All @@ -118,7 +166,7 @@ def write_state(self, state):
keys.update(self.encoding)

keys = grib_keys(
values=value,
values=values,
template=template,
date=reference_date.strftime("%Y-%m-%d"),
time=reference_date.hour,
Expand All @@ -131,11 +179,14 @@ def write_state(self, state):
quiet=self.quiet,
)

for modifier in self.modifiers:
values, template, keys = modifier(values, template, keys)

if LOG.isEnabledFor(logging.DEBUG):
LOG.info("Encoding GRIB %s\n%s", template, json.dumps(keys, indent=4))

try:
self.write_message(value, template=template, **keys)
self.write_message(values, template=template, **keys)
except Exception:
LOG.error("Error writing field %s", name)
LOG.error("Template: %s", template)
Expand Down
25 changes: 24 additions & 1 deletion src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,17 @@ def __init__(
templates=None,
grib1_keys=None,
grib2_keys=None,
modifiers=None,
**kwargs,
):
super().__init__(context, encoding=encoding, templates=templates, grib1_keys=grib1_keys, grib2_keys=grib2_keys)
super().__init__(
context,
encoding=encoding,
templates=templates,
grib1_keys=grib1_keys,
grib2_keys=grib2_keys,
modifiers=modifiers,
)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)
self.archiving = defaultdict(ArchiveCollector)
Expand All @@ -74,6 +82,20 @@ def __repr__(self):
return f"GribFileOutput({self.path})"

def write_message(self, message, template, **keys):
# Make sure `name` is not in the keys, otherwise grib_encoding will fail
if template is not None and template.metadata("name", default=None) is not None:
# We cannot clear the metadata...
class Dummy:
def __init__(self, template):
self.template = template
self.handle = template.handle

def __repr__(self):
return f"Dummy({self.template})"

template = Dummy(template)

# LOG.info("Writing message to %s %s", template, keys)
try:
self.collect_archive_requests(
self.output.write(
Expand All @@ -90,6 +112,7 @@ def write_message(self, message, template, **keys):

LOG.error("Error writing message to %s", self.path)
LOG.error("eccodes: %s", eccodes.__version__)
LOG.error("Template: %s, Keys: %s", template, keys)
LOG.error("Exception: %s", e)
if message is not None and np.isnan(message.data).any():
LOG.error("Message contains NaNs (%s, %s) (allow_nans=%s)", keys, template, self.context.allow_nans)
Expand Down
Loading