Skip to content

Commit

Permalink
fix bug in hf_utils model downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
Borg93 committed Apr 22, 2024
1 parent 7a4fdc5 commit 0a8d21a
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/htrflow_core/models/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

# TODO: add pytest

logger = logging.getLogger(__name__)


class HFBaseDownloader:
META_MODEL_TYPE = "model"
Expand Down Expand Up @@ -77,9 +79,9 @@ class MMLabsDownloader(HFBaseDownloader):
def from_pretrained(
cls,
model_id: str,
config_id: Optional[str] = None,
cache_dir: str = "./.cache",
hf_token: Optional[str] = None,
config_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Download and load config and model from Openmmlabs using the HuggingFace Hub."""

Expand All @@ -88,6 +90,7 @@ def from_pretrained(
existing_model = downloader._mmlab_try_load_from_local_files(model_id, config_id)

if existing_model:
logging.info(f"Loaded existing model from '{existing_model}'")
return existing_model

repo_files = downloader.list_files_from_repo(model_id)
Expand All @@ -105,14 +108,16 @@ def from_pretrained(
dictionary_path = downloader.wrapper_hf_hub_download(model_id, dict_file)
downloader._fix_mmlab_dict_file(config_path, dictionary_path)

logging.info(f"Downloaded model '{model_id}' from HF and loaded it from folder: '{cache_dir}'")

return model_path, config_path

def _mmlab_try_load_from_local_files(self, model_id, config_id) -> Optional[Tuple[str, str]]:
model_path = Path(model_id)
config_path = Path(config_id)
if model_path.exists() and model_path.suffix in self.MMLABS_SUPPORTED_MODEL_TYPES:
if config_path.exists() and config_path.suffix == self.PY_EXTENSION:
return model_id, config_id
return str(model_id), str(config_id)
elif config_path.exists() and config_path.suffix != self.PY_EXTENSION:
raise ValueError(f"Please provide config of type: {self.MMLABS_CONFIG_FILE}")
return None
Expand All @@ -135,12 +140,15 @@ def from_pretrained(cls, model_id: str, cache_dir: str = "./.cache", hf_token: O
downloader = cls(cache_dir=cache_dir, hf_token=hf_token)
existing_model = downloader._ultralytics_try_load_from_local_files(model_id)
if existing_model:
return f"Loaded existing model from {existing_model}"
logging.info(f"Loaded existing model from '{existing_model}'")
return existing_model

repo_files = downloader.list_files_from_repo(model_id)
return downloader._download_file_from_hf(
cache_model_path = downloader._download_file_from_hf(
model_id, cls.ULTRALYTICS_SUPPORTED_MODEL_TYPES, cls.META_MODEL_TYPE, repo_files
)
logging.info(f"Downloaded model '{model_id}' from HF and loaded it from folder: '{cache_dir}'")
return cache_model_path

def _ultralytics_try_load_from_local_files(self, model_id: str) -> Optional[str]:
"""Check for an existing local file for the model."""
Expand Down

0 comments on commit 0a8d21a

Please sign in to comment.