From 3cf43b737bd8353e8c172b61df40ba79a3d21f15 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 29 May 2024 07:47:22 +0100 Subject: [PATCH] Tidy up code --- .../inference/checkpoint/metadata/__init__.py | 37 +---- .../checkpoint/metadata/version_0_2_0.py | 152 ++++++++++-------- src/anemoi/inference/commands/checkpoint.py | 20 +-- 3 files changed, 94 insertions(+), 115 deletions(-) diff --git a/src/anemoi/inference/checkpoint/metadata/__init__.py b/src/anemoi/inference/checkpoint/metadata/__init__.py index caaff64..07bf70d 100644 --- a/src/anemoi/inference/checkpoint/metadata/__init__.py +++ b/src/anemoi/inference/checkpoint/metadata/__init__.py @@ -305,42 +305,7 @@ def multi_step(self): @cached_property def imputable_variables(self): - result = [] - - def from_input_imputer(config): - for k, v in config.items(): - if not isinstance(v, list): - v = [v] - yield from v - - def from_constant_imputer(config): - yield from config.keys() - - def empty(config): - return [] - - IMPUTERS = { - "aifs.preprocessing.imputer.InputImputer": from_input_imputer, - "aifs.preprocessing.imputer.ConstantImputer": from_constant_imputer, - } - - for k, v in self._config_data.get("processors", {}).items(): - target = v.get("_target_") - if target is None: - continue - - if "imputer" in target.lower() and target not in IMPUTERS: - LOG.warning("Unknown imputer %s, ignoring", target) - continue - - source = IMPUTERS.get(target, empty) - result.extend(source(v.get("config", {}))) - - result = sorted(set(result)) - if result: - LOG.info("Imputable variables %s", result) - - return result + return self.variables_with_nans def rounded_area(self, area): surface = (area[0] - area[2]) * (area[3] - area[1]) / 180 / 360 diff --git a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py index b83939a..470928a 100644 --- a/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py +++ b/src/anemoi/inference/checkpoint/metadata/version_0_2_0.py @@ -25,6 +25,62 @@ def variables(self): def __repr__(self) -> str: return self.__class__.__name__ + def mars_request(self): + + def _as_list(v): + if isinstance(v, list): + return v + return [v] + + def _as_string(r): + r = {k: "/".join([str(x) for x in _as_list(v)]) for k, v in r.items() if v} + return ",".join([f"{k}={v}" for k, v in r.items()]) + + r = dict(grid=self.grid, area=self.area) + yield _as_string(r) + + r = dict(param=_as_list(self.param_sfc)) + yield _as_string(r) + + param, pl = self.param_level_pl + if param: + r = dict(param=param, level=pl) + yield _as_string(r) + + param, ml = self.param_level_ml + if param: + r = dict(param=param, level=ml) + yield _as_string(r) + + nans = self.variables_with_nans + if nans: + r = dict(with_nans=nans) + yield _as_string(r) + + def dump_content(self, indent=0): + print() + print(" " * indent, "-", self) + for n in self.mars_request(): + print(" " * indent, " ", n) + + @property + def param_sfc(self): + param_sfc = self.forward.param_sfc + param_step_sfc = [p[0] for p in self.forward.param_step_sfc_pairs] + return [p for p in param_sfc if p not in param_step_sfc] + + @property + def param_level_pl(self): + params = set([p[0] for p in self.param_level_pl_pairs]) + levels = set([p[1] for p in self.param_level_pl_pairs]) + return sorted(params), sorted(levels) + + @property + def param_level_ml(self): + params = set([p[0] for p in self.param_level_ml_pairs]) + levels = set([p[1] for p in self.param_level_ml_pairs]) + return sorted(params), sorted(levels) + class ZarrRequest(DataRequest): def __init__(self, metadata): @@ -44,24 +100,23 @@ def param_sfc(self): return self.request["param_level"].get("sfc", []) @property - def param_level_pl(self): + def param_level_pl_pairs(self): return self.request["param_level"].get("pl", []) @property - def param_level_ml(self): + def param_level_ml_pairs(self): return self.request["param_level"].get("ml", []) @property - def param_step_sfc(self): + def param_step_sfc_pairs(self): return self.request["param_step"].get("sfc", []) @property def variables_with_nans(self): return sorted(self.attributes.get("variables_with_nans", [])) - def dump(self, indent): - print(" " * indent, self) - print(" " * indent, self.request) + def dump(self, indent=0): + self.dump_content(indent) class Forward(DataRequest): @@ -72,8 +127,8 @@ def forward(self): def __getattr__(self, name): return getattr(self.forward, name) - def dump(self, indent): - print(" " * indent, self) + def dump(self, indent=0): + self.dump_content(indent) self.forward.dump(indent + 2) @@ -88,13 +143,10 @@ class StatisticsRequest(Forward): class RenameRequest(Forward): - @property - def variables(self): - raise NotImplementedError() - @property def variables_with_nans(self): raise NotImplementedError() + return sorted(self.forward.variables_with_nans) class MultiRequest(Forward): @@ -106,8 +158,8 @@ def __init__(self, metadata): def forward(self): return self.datasets[0] - def dump(self, indent): - print(" " * indent, self) + def dump(self, indent=0): + self.dump_content(indent) for dataset in self.datasets: dataset.dump(indent + 2) @@ -123,10 +175,10 @@ def param_sfc(self): return result @property - def param_level_pl(self): + def param_level_pl_pairs(self): result = [] for dataset in self.datasets: - for param in dataset.param_level_pl: + for param in dataset.param_level_pl_pairs: if param not in result: result.append(param) return result @@ -134,25 +186,21 @@ def param_level_pl(self): @property def param_level_ml(self): result = [] - for dataset in self.datasets: - for param in dataset.param_level_ml: + for dataset in self.datasets_pairs: + for param in dataset.param_level_ml_pairs: if param not in result: result.append(param) return result @property - def param_step_sfc(self): + def param_step_sfc_pairs(self): result = [] for dataset in self.datasets: - for param in dataset.param_step_sfc: + for param in dataset.param_step_sfc_pairs: if param not in result: result.append(param) return result - @property - def variables(self): - raise NotImplementedError() - @property def variables_with_nans(self): result = set() @@ -184,20 +232,20 @@ def param_sfc(self): return [x for x in self.forward.param_sfc if x in self.variables] @property - def param_level_pl(self): - return [x for x in self.forward.param_level_pl if f"{x[0]}_{x[1]}" in self.variables] + def param_level_pl_pairs(self): + return [x for x in self.forward.param_level_pl_pairs if f"{x[0]}_{x[1]}" in self.variables] @property - def param_level_ml(self): - return [x for x in self.forward.param_level_ml if f"{x[0]}_{x[1]}" in self.variables] + def param_level_ml_pairs(self): + return [x for x in self.forward.param_level_ml_pairs if f"{x[0]}_{x[1]}" in self.variables] @property - def param_step(self): - return [x for x in self.forward.param_step if x[0] in self.variables] + def param_step_pairs(self): + return [x for x in self.forward.param_step_pairs if x[0] in self.variables] @property - def param_step_sfc(self): - return [x for x in self.forward.param_step_sfc if x[0] in self.variables] + def param_step_sfc_pairs(self): + return [x for x in self.forward.param_step_sfc_pairs if x[0] in self.variables] @property def variables_with_nans(self): @@ -226,48 +274,14 @@ def data_request(specific): return globals()[action](specific) -class Version_0_2_0(Metadata): +class Version_0_2_0(Metadata, Forward): def __init__(self, metadata): super().__init__(metadata) specific = metadata["dataset"]["specific"] - self.data_request = data_request(specific) - self.data_request.dump(0) - - @property - def variables(self): - return self.data_request.variables + self.forward = data_request(specific) @cached_property def area(self): - return self.rounded_area(self.data_request.area) - - @property - def grid(self): - return self.data_request.grid - - @cached_property - def variables_with_nans(self): - return self.data_request.variables_with_nans + return self.rounded_area(self.forward.area) ######################### - - @property - def param_sfc(self): - param_sfc = self.data_request.param_sfc - # Remove diagnostic variables - param_step_sfc = [p[0] for p in self.data_request.param_step_sfc] - return [p for p in param_sfc if p not in param_step_sfc] - - @property - def param_level_pl(self): - param_level_pl = self.data_request.param_level_pl - params = set([p[0] for p in param_level_pl]) - levels = set([p[1] for p in param_level_pl]) - return sorted(params), sorted(levels) - - @property - def param_level_ml(self): - param_level_ml = self.data_request.param_level_ml - params = set([p[0] for p in param_level_ml]) - levels = set([p[1] for p in param_level_ml]) - return sorted(params), sorted(levels) diff --git a/src/anemoi/inference/commands/checkpoint.py b/src/anemoi/inference/commands/checkpoint.py index ba2c8c4..423a074 100644 --- a/src/anemoi/inference/commands/checkpoint.py +++ b/src/anemoi/inference/commands/checkpoint.py @@ -19,27 +19,29 @@ class CheckpointCmd(Command): def add_arguments(self, command_parser): command_parser.add_argument("path", help="Path to the checkpoint.") + command_parser.add_argument("--dump", action="store_true", help="Print internal information") def run(self, args): c = Checkpoint(args.path) + + if args.dump: + c.dump() + return + print("area:", c.area) - print("computed_constants:", c.computed_constants) print("computed_constants_mask:", c.computed_constants_mask) - print("computed_forcings:", c.computed_forcings) + print("computed_constants:", c.computed_constants) print("computed_forcings_mask:", c.computed_forcings_mask) + print("computed_forcings:", c.computed_forcings) print("constant_data_from_input_mask:", c.constant_data_from_input_mask) - print("constants_from_input:", c.constants_from_input) print("constants_from_input_mask:", c.constants_from_input_mask) - print("data_request:", c.data_request) + print("constants_from_input:", c.constants_from_input) print("data_to_model:", c.data_to_model) print("diagnostic_output_mask:", c.diagnostic_output_mask) print("diagnostic_params:", c.diagnostic_params) - print("from_metadata:", c.from_metadata) print("grid:", c.grid) print("hour_steps:", c.hour_steps) - print("imputable variables", c.imputable_variables) - print("imputable variables:", c.imputable_variables) print("imputable_variables:", c.imputable_variables) print("index_to_variable:", c.index_to_variable) print("model_to_data:", c.model_to_data) @@ -55,12 +57,10 @@ def run(self, args): print("prognostic_input_mask:", c.prognostic_input_mask) print("prognostic_output_mask:", c.prognostic_output_mask) print("prognostic_params:", c.prognostic_params) - print("report_loading_error:", c.report_loading_error) - print("rounded_area:", c.rounded_area) print("select:", c.select) print("variable_to_index:", c.variable_to_index) + print("variables_with_nans:", c.variables_with_nans) print("variables:", c.variables) - print("variables_with_nans::", c.variables_with_nans) command = CheckpointCmd