Skip to content

Commit

Permalink
feat: store model download paths in "path" key (#274)
Browse files Browse the repository at this point in the history
Models downloaded wrt `models_to_fetch` dict have the output path of the
downloaded model files in the "path" key for each model.

Signed-off-by: Anupam Kumar <[email protected]>
Co-authored-by: Alexander Piskun <[email protected]>
  • Loading branch information
kyteinsky and bigcat88 authored Jul 18, 2024
1 parent 1bcd2b9 commit 1926973
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions nc_py_api/ex_app/integration_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,26 @@ def fetch_models_task(nc: NextcloudApp, models: dict[str, dict], progress_init_s
percent_for_each = min(int((100 - progress_init_start_value) / len(models)), 99)
for model in models:
if model.startswith(("http://", "https://")):
__fetch_model_as_file(current_progress, percent_for_each, nc, model, models[model])
models[model]["path"] = __fetch_model_as_file(
current_progress, percent_for_each, nc, model, models[model]
)
else:
__fetch_model_as_snapshot(current_progress, percent_for_each, nc, model, models[model])
models[model]["path"] = __fetch_model_as_snapshot(
current_progress, percent_for_each, nc, model, models[model]
)
current_progress += percent_for_each
nc.set_init_status(100)


def __fetch_model_as_file(
current_progress: int, progress_for_task: int, nc: NextcloudApp, model_path: str, download_options: dict
) -> None:
) -> str | None:
result_path = download_options.pop("save_path", urlparse(model_path).path.split("/")[-1])
try:
with httpx.stream("GET", model_path, follow_redirects=True) as response:
if not response.is_success:
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' returned {response.status_code} status.")
return
return None
downloaded_size = 0
linked_etag = ""
for each_history in response.history:
Expand All @@ -163,7 +167,7 @@ def __fetch_model_as_file(
sha256_hash.update(byte_block)
if f'"{sha256_hash.hexdigest()}"' == linked_etag:
nc.set_init_status(min(current_progress + progress_for_task, 99))
return
return None

with builtins.open(result_path, "wb") as file:
last_progress = current_progress
Expand All @@ -174,13 +178,17 @@ def __fetch_model_as_file(
if new_progress != last_progress:
nc.set_init_status(new_progress)
last_progress = new_progress

return result_path
except Exception as e: # noqa pylint: disable=broad-exception-caught
nc.log(LogLvl.ERROR, f"Downloading of '{model_path}' raised an exception: {e}")

return None


def __fetch_model_as_snapshot(
current_progress: int, progress_for_task, nc: NextcloudApp, mode_name: str, download_options: dict
) -> None:
) -> str:
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401

Expand All @@ -191,7 +199,9 @@ def display(self, msg=None, pos=None):

workers = download_options.pop("max_workers", 2)
cache = download_options.pop("cache_dir", persistent_storage())
snapshot_download(mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache)
return snapshot_download(
mode_name, tqdm_class=TqdmProgress, **download_options, max_workers=workers, cache_dir=cache
)


def __nc_app(request: HTTPConnection) -> dict:
Expand Down

0 comments on commit 1926973

Please sign in to comment.