Skip to content

Commit

Permalink
Tidy up code
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 29, 2024
1 parent 621b17a commit 3cf43b7
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 115 deletions.
37 changes: 1 addition & 36 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,42 +305,7 @@ def multi_step(self):

@cached_property
def imputable_variables(self):
result = []

def from_input_imputer(config):
for k, v in config.items():
if not isinstance(v, list):
v = [v]
yield from v

def from_constant_imputer(config):
yield from config.keys()

def empty(config):
return []

IMPUTERS = {
"aifs.preprocessing.imputer.InputImputer": from_input_imputer,
"aifs.preprocessing.imputer.ConstantImputer": from_constant_imputer,
}

for k, v in self._config_data.get("processors", {}).items():
target = v.get("_target_")
if target is None:
continue

if "imputer" in target.lower() and target not in IMPUTERS:
LOG.warning("Unknown imputer %s, ignoring", target)
continue

source = IMPUTERS.get(target, empty)
result.extend(source(v.get("config", {})))

result = sorted(set(result))
if result:
LOG.info("Imputable variables %s", result)

return result
return self.variables_with_nans

def rounded_area(self, area):
surface = (area[0] - area[2]) * (area[3] - area[1]) / 180 / 360
Expand Down
152 changes: 83 additions & 69 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,62 @@ def variables(self):
def __repr__(self) -> str:
return self.__class__.__name__

def mars_request(self):

def _as_list(v):
if isinstance(v, list):
return v
return [v]

def _as_string(r):
r = {k: "/".join([str(x) for x in _as_list(v)]) for k, v in r.items() if v}
return ",".join([f"{k}={v}" for k, v in r.items()])

r = dict(grid=self.grid, area=self.area)
yield _as_string(r)

r = dict(param=_as_list(self.param_sfc))
yield _as_string(r)

param, pl = self.param_level_pl
if param:
r = dict(param=param, level=pl)
yield _as_string(r)

param, ml = self.param_level_ml
if param:
r = dict(param=param, level=ml)
yield _as_string(r)

nans = self.variables_with_nans
if nans:
r = dict(with_nans=nans)
yield _as_string(r)

def dump_content(self, indent=0):
print()
print(" " * indent, "-", self)
for n in self.mars_request():
print(" " * indent, " ", n)

@property
def param_sfc(self):
param_sfc = self.forward.param_sfc
param_step_sfc = [p[0] for p in self.forward.param_step_sfc_pairs]
return [p for p in param_sfc if p not in param_step_sfc]

@property
def param_level_pl(self):
params = set([p[0] for p in self.param_level_pl_pairs])
levels = set([p[1] for p in self.param_level_pl_pairs])
return sorted(params), sorted(levels)

@property
def param_level_ml(self):
params = set([p[0] for p in self.param_level_ml_pairs])
levels = set([p[1] for p in self.param_level_ml_pairs])
return sorted(params), sorted(levels)


class ZarrRequest(DataRequest):
def __init__(self, metadata):
Expand All @@ -44,24 +100,23 @@ def param_sfc(self):
return self.request["param_level"].get("sfc", [])

@property
def param_level_pl(self):
def param_level_pl_pairs(self):
return self.request["param_level"].get("pl", [])

@property
def param_level_ml(self):
def param_level_ml_pairs(self):
return self.request["param_level"].get("ml", [])

@property
def param_step_sfc(self):
def param_step_sfc_pairs(self):
return self.request["param_step"].get("sfc", [])

@property
def variables_with_nans(self):
return sorted(self.attributes.get("variables_with_nans", []))

def dump(self, indent):
print(" " * indent, self)
print(" " * indent, self.request)
def dump(self, indent=0):
self.dump_content(indent)


class Forward(DataRequest):
Expand All @@ -72,8 +127,8 @@ def forward(self):
def __getattr__(self, name):
return getattr(self.forward, name)

def dump(self, indent):
print(" " * indent, self)
def dump(self, indent=0):
self.dump_content(indent)
self.forward.dump(indent + 2)


Expand All @@ -88,13 +143,10 @@ class StatisticsRequest(Forward):

class RenameRequest(Forward):

@property
def variables(self):
raise NotImplementedError()

@property
def variables_with_nans(self):
raise NotImplementedError()
return sorted(self.forward.variables_with_nans)


class MultiRequest(Forward):
Expand All @@ -106,8 +158,8 @@ def __init__(self, metadata):
def forward(self):
return self.datasets[0]

def dump(self, indent):
print(" " * indent, self)
def dump(self, indent=0):
self.dump_content(indent)
for dataset in self.datasets:
dataset.dump(indent + 2)

Expand All @@ -123,36 +175,32 @@ def param_sfc(self):
return result

@property
def param_level_pl(self):
def param_level_pl_pairs(self):
result = []
for dataset in self.datasets:
for param in dataset.param_level_pl:
for param in dataset.param_level_pl_pairs:
if param not in result:
result.append(param)
return result

@property
def param_level_ml(self):
result = []
for dataset in self.datasets:
for param in dataset.param_level_ml:
for dataset in self.datasets_pairs:
for param in dataset.param_level_ml_pairs:
if param not in result:
result.append(param)
return result

@property
def param_step_sfc(self):
def param_step_sfc_pairs(self):
result = []
for dataset in self.datasets:
for param in dataset.param_step_sfc:
for param in dataset.param_step_sfc_pairs:
if param not in result:
result.append(param)
return result

@property
def variables(self):
raise NotImplementedError()

@property
def variables_with_nans(self):
result = set()
Expand Down Expand Up @@ -184,20 +232,20 @@ def param_sfc(self):
return [x for x in self.forward.param_sfc if x in self.variables]

@property
def param_level_pl(self):
return [x for x in self.forward.param_level_pl if f"{x[0]}_{x[1]}" in self.variables]
def param_level_pl_pairs(self):
return [x for x in self.forward.param_level_pl_pairs if f"{x[0]}_{x[1]}" in self.variables]

@property
def param_level_ml(self):
return [x for x in self.forward.param_level_ml if f"{x[0]}_{x[1]}" in self.variables]
def param_level_ml_pairs(self):
return [x for x in self.forward.param_level_ml_pairs if f"{x[0]}_{x[1]}" in self.variables]

@property
def param_step(self):
return [x for x in self.forward.param_step if x[0] in self.variables]
def param_step_pairs(self):
return [x for x in self.forward.param_step_pairs if x[0] in self.variables]

@property
def param_step_sfc(self):
return [x for x in self.forward.param_step_sfc if x[0] in self.variables]
def param_step_sfc_pairs(self):
return [x for x in self.forward.param_step_sfc_pairs if x[0] in self.variables]

@property
def variables_with_nans(self):
Expand Down Expand Up @@ -226,48 +274,14 @@ def data_request(specific):
return globals()[action](specific)


class Version_0_2_0(Metadata):
class Version_0_2_0(Metadata, Forward):
def __init__(self, metadata):
super().__init__(metadata)
specific = metadata["dataset"]["specific"]
self.data_request = data_request(specific)
self.data_request.dump(0)

@property
def variables(self):
return self.data_request.variables
self.forward = data_request(specific)

@cached_property
def area(self):
return self.rounded_area(self.data_request.area)

@property
def grid(self):
return self.data_request.grid

@cached_property
def variables_with_nans(self):
return self.data_request.variables_with_nans
return self.rounded_area(self.forward.area)

#########################

@property
def param_sfc(self):
param_sfc = self.data_request.param_sfc
# Remove diagnostic variables
param_step_sfc = [p[0] for p in self.data_request.param_step_sfc]
return [p for p in param_sfc if p not in param_step_sfc]

@property
def param_level_pl(self):
param_level_pl = self.data_request.param_level_pl
params = set([p[0] for p in param_level_pl])
levels = set([p[1] for p in param_level_pl])
return sorted(params), sorted(levels)

@property
def param_level_ml(self):
param_level_ml = self.data_request.param_level_ml
params = set([p[0] for p in param_level_ml])
levels = set([p[1] for p in param_level_ml])
return sorted(params), sorted(levels)
20 changes: 10 additions & 10 deletions src/anemoi/inference/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,29 @@ class CheckpointCmd(Command):

def add_arguments(self, command_parser):
command_parser.add_argument("path", help="Path to the checkpoint.")
command_parser.add_argument("--dump", action="store_true", help="Print internal information")

def run(self, args):

c = Checkpoint(args.path)

if args.dump:
c.dump()
return

print("area:", c.area)
print("computed_constants:", c.computed_constants)
print("computed_constants_mask:", c.computed_constants_mask)
print("computed_forcings:", c.computed_forcings)
print("computed_constants:", c.computed_constants)
print("computed_forcings_mask:", c.computed_forcings_mask)
print("computed_forcings:", c.computed_forcings)
print("constant_data_from_input_mask:", c.constant_data_from_input_mask)
print("constants_from_input:", c.constants_from_input)
print("constants_from_input_mask:", c.constants_from_input_mask)
print("data_request:", c.data_request)
print("constants_from_input:", c.constants_from_input)
print("data_to_model:", c.data_to_model)
print("diagnostic_output_mask:", c.diagnostic_output_mask)
print("diagnostic_params:", c.diagnostic_params)
print("from_metadata:", c.from_metadata)
print("grid:", c.grid)
print("hour_steps:", c.hour_steps)
print("imputable variables", c.imputable_variables)
print("imputable variables:", c.imputable_variables)
print("imputable_variables:", c.imputable_variables)
print("index_to_variable:", c.index_to_variable)
print("model_to_data:", c.model_to_data)
Expand All @@ -55,12 +57,10 @@ def run(self, args):
print("prognostic_input_mask:", c.prognostic_input_mask)
print("prognostic_output_mask:", c.prognostic_output_mask)
print("prognostic_params:", c.prognostic_params)
print("report_loading_error:", c.report_loading_error)
print("rounded_area:", c.rounded_area)
print("select:", c.select)
print("variable_to_index:", c.variable_to_index)
print("variables_with_nans:", c.variables_with_nans)
print("variables:", c.variables)
print("variables_with_nans::", c.variables_with_nans)


command = CheckpointCmd

0 comments on commit 3cf43b7

Please sign in to comment.