From b4b87f64495b2be94dd3f76c7588183e26fdca2a Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Sat, 14 Dec 2024 14:52:09 +0000 Subject: [PATCH 1/2] Hotfix: huggingface loading issues (#81) * Hotfix: huggingface loading --- src/anemoi/inference/checkpoint.py | 25 ++++++++++++++----------- src/anemoi/inference/config.py | 2 +- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 5a1754d8..fea51f3d 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -30,6 +30,9 @@ def _download_huggingfacehub(huggingface_config): except ImportError as e: raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e + if isinstance(huggingface_config, str): + huggingface_config = {"repo_id": huggingface_config} + if "filename" in huggingface_config: config_path = hf_hub_download(**huggingface_config) else: @@ -48,7 +51,7 @@ class Checkpoint: """Represents an inference checkpoint.""" def __init__(self, path, *, patch_metadata=None): - self.path = path + self._path = path self.patch_metadata = patch_metadata def __repr__(self): @@ -59,17 +62,17 @@ def path(self): import json try: - self._model = json.loads(self._model) - except TypeError: - pass - - if isinstance(self._model, str): - return self._model - elif isinstance(self._model, dict): - if "huggingface" in self._model: - return _download_huggingfacehub(self._model["huggingface"]) + path = json.loads(self._path) + except Exception: + path = self._path + + if isinstance(path, (Path, str)): + return path + elif isinstance(path, dict): + if "huggingface" in path: + return _download_huggingfacehub(path["huggingface"]) pass - raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict") + raise TypeError(f"Cannot parse model path: {path}. It must be a path or dict") @cached_property def _metadata(self): diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 738d2f67..f62a9c8a 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -29,7 +29,7 @@ class Config: description: str | None = None - checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]] + checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str] """A path to an Anemoi checkpoint file.""" date: str | int | datetime.datetime | None = None From 5b9f7ab8471569aa098ed025b0cd738c09896b22 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Sat, 14 Dec 2024 14:54:10 +0000 Subject: [PATCH 2/2] Hotfix: huggingface loading issues (#81) (#82) * Hotfix: huggingface loading --- src/anemoi/inference/checkpoint.py | 25 ++++++++++++++----------- src/anemoi/inference/config.py | 2 +- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 5a1754d8..fea51f3d 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -30,6 +30,9 @@ def _download_huggingfacehub(huggingface_config): except ImportError as e: raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e + if isinstance(huggingface_config, str): + huggingface_config = {"repo_id": huggingface_config} + if "filename" in huggingface_config: config_path = hf_hub_download(**huggingface_config) else: @@ -48,7 +51,7 @@ class Checkpoint: """Represents an inference checkpoint.""" def __init__(self, path, *, patch_metadata=None): - self.path = path + self._path = path self.patch_metadata = patch_metadata def __repr__(self): @@ -59,17 +62,17 @@ def path(self): import json try: - self._model = json.loads(self._model) - except TypeError: - pass - - if isinstance(self._model, str): - return self._model - elif isinstance(self._model, dict): - if "huggingface" in self._model: - return _download_huggingfacehub(self._model["huggingface"]) + path = json.loads(self._path) + except Exception: + path = self._path + + if isinstance(path, (Path, str)): + return path + elif isinstance(path, dict): + if "huggingface" in path: + return _download_huggingfacehub(path["huggingface"]) pass - raise TypeError(f"Cannot parse model path: {self._model}. It must be a path or dict") + raise TypeError(f"Cannot parse model path: {path}. It must be a path or dict") @cached_property def _metadata(self): diff --git a/src/anemoi/inference/config.py b/src/anemoi/inference/config.py index 738d2f67..f62a9c8a 100644 --- a/src/anemoi/inference/config.py +++ b/src/anemoi/inference/config.py @@ -29,7 +29,7 @@ class Config: description: str | None = None - checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any]] + checkpoint: str | Dict[Literal["huggingface"], Dict[str, Any] | str] """A path to an Anemoi checkpoint file.""" date: str | int | datetime.datetime | None = None