diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 9410aca..5a1754d 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -12,6 +12,7 @@ import logging from collections import defaultdict from functools import cached_property +from pathlib import Path from anemoi.utils.checkpoints import load_metadata from earthkit.data.utils.dates import to_datetime @@ -25,10 +26,21 @@ def _download_huggingfacehub(huggingface_config): """Download model from huggingface""" try: from huggingface_hub import hf_hub_download + from huggingface_hub import snapshot_download except ImportError as e: raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e - config_path = hf_hub_download(**huggingface_config) + if "filename" in huggingface_config: + config_path = hf_hub_download(**huggingface_config) + else: + repo_path = Path(snapshot_download(**huggingface_config)) + ckpt_files = list(repo_path.glob("*.ckpt")) + if len(ckpt_files) == 1: + return str(ckpt_files[0]) + else: + ValueError( + f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`." + ) return config_path diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index 58d9096..29a5791 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -234,7 +234,7 @@ def variables_metadata(self): result = self._metadata.dataset.variables_metadata self._legacy_check_variables_metadata(result) except AttributeError: - return self._legacy_variables_metadata() + result = self._legacy_variables_metadata() if "constant_fields" in self._metadata.dataset: for name in self._metadata.dataset.constant_fields: diff --git a/src/anemoi/inference/outputs/gribfile.py b/src/anemoi/inference/outputs/gribfile.py index d6d9bf9..9fc9995 100644 --- a/src/anemoi/inference/outputs/gribfile.py +++ b/src/anemoi/inference/outputs/gribfile.py @@ -24,6 +24,8 @@ class ArchiveCollector: + """Collects archive requests""" + UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"} def __init__(self) -> None: diff --git a/src/anemoi/inference/outputs/raw.py b/src/anemoi/inference/outputs/raw.py index 7a4ed06..a46ec52 100644 --- a/src/anemoi/inference/outputs/raw.py +++ b/src/anemoi/inference/outputs/raw.py @@ -48,6 +48,8 @@ def write_state(self, state): date = state["date"].strftime(self.strftime) fn_state = f"{self.path}/{self.template.format(date=date)}" restate = {f"field_{key}": val for key, val in state["fields"].items()} - for key in ["date", "longitudes", "latitudes"]: + for key in ["date"]: restate[key] = np.array(state[key], dtype=str) + for key in ["latitudes", "longitudes"]: + restate[key] = np.array(state[key]) np.savez_compressed(fn_state, **restate)