diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 3e52130..fea51f3 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -30,6 +30,9 @@ def _download_huggingfacehub(huggingface_config): except ImportError as e: raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e + if isinstance(huggingface_config, str): + huggingface_config = {"repo_id": huggingface_config} + if "filename" in huggingface_config: config_path = hf_hub_download(**huggingface_config) else: diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 738d2f6..f62a9c8 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -29,7 +29,7 @@ class Config: description: str | None = None - checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]] + checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str] """A path to an Anemoi checkpoint file.""" date: str | int | datetime.datetime | None = None