Skip to content

Commit

Permalink
Update to download only relevant files and not the whole model repo
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Apr 5, 2024
1 parent 0e16e99 commit b85b069
Showing 1 changed file with 48 additions and 11 deletions.
59 changes: 48 additions & 11 deletions src/sparseml/transformers/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import PaddingStrategy

from huggingface_hub import HUGGINGFACE_CO_URL_HOME, hf_hub_download, snapshot_download
from huggingface_hub import HUGGINGFACE_CO_URL_HOME, HfFileSystem, hf_hub_download
from sparseml.export.helpers import ONNX_MODEL_NAME
from sparseml.utils import download_zoo_training_dir
from sparseml.utils.fsdp.context import main_process_first_context
Expand Down Expand Up @@ -96,6 +96,7 @@ class TaskNames(Enum):
"special_tokens_map.json",
"tokenizer_config.json",
}
RELEVANT_HF_SUFFIXES = ["json", "md", "bin", "safetensors", "yaml", "yml"]


def remove_past_key_value_support_from_config(config: AutoConfig) -> AutoConfig:
Expand Down Expand Up @@ -561,25 +562,61 @@ def fetch_recipe_path(target: str):

def download_repo_from_huggingface_hub(repo_id, **kwargs):
"""
Download a model repo from the Hugging Face Hub
using the huggingface_hub.snapshot_download function
Download relevant model files from the Hugging Face Hub
using the huggingface_hub.hf_hub_download function
Note(s):
- Does not download the entire repo, only the relevant files
for the model, such as the model weights, tokenizer files, etc.
- Does not re-download files that already exist locally, unless
the force_download flag is set to True
:pre-condition: the repo_id must be a valid Hugging Face Hub repo id
:param repo_id: the repo id to download
:param kwargs: additional keyword arguments to pass to snapshot_download
:param kwargs: additional keyword arguments to pass to hf_hub_download
"""
hub_kwargs_names = [
hf_filesystem = HfFileSystem()
files = hf_filesystem.ls(repo_id)

if not files:
raise ValueError(f"Could not find any files in HF repo {repo_id}")

# All file(s) from hf_filesystem have "name" key
# Extract the file names from the files
relevant_file_names = (
Path(file["name"]).name
for file in files
if any(file["name"].endswith(suffix) for suffix in RELEVANT_HF_SUFFIXES)
)

hub_kwargs_names = (
"subfolder",
"repo_type",
"revision",
"library_name",
"library_version",
"cache_dir",
"local_dir",
"local_dir_use_symlinks",
"user_agent",
"force_download",
"local_files_only",
"force_filename",
"proxies",
"etag_timeout",
"resume_download",
"revision",
"subfolder",
"use_auth_token",
"token",
]
"local_files_only",
"headers",
"legacy_cache_layout",
"endpoint",
)
hub_kwargs = {name: kwargs[name] for name in hub_kwargs_names if name in kwargs}
return snapshot_download(repo_id, **hub_kwargs)

for file_name in relevant_file_names:
last_file = hf_hub_download(repo_id=repo_id, filename=file_name, **hub_kwargs)

# parent directory of the last file is the model directory
return str(Path(last_file).parent.resolve().absolute())


def download_model_directory(pretrained_model_name_or_path: str, **kwargs):
Expand Down

0 comments on commit b85b069

Please sign in to comment.