From cf26cbd2c3f25886fa3576ef54828268e59eb52c Mon Sep 17 00:00:00 2001 From: Markus Schuettler Date: Fri, 22 Nov 2024 11:36:04 +0100 Subject: [PATCH] wip gguf model downloader --- WebUI/src/views/Create.vue | 17 +++++++++++++++++ service/model_downloader.py | 14 ++++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/WebUI/src/views/Create.vue b/WebUI/src/views/Create.vue index 0ff1f4de..0998c8e9 100644 --- a/WebUI/src/views/Create.vue +++ b/WebUI/src/views/Create.vue @@ -103,10 +103,27 @@ const emits = defineEmits<{ }>(); async function generateImage() { + await ensureModelsAreAvailable(); reset(); await imageGeneration.generate(); } +async function ensureModelsAreAvailable() { + return new Promise(async (resolve, reject) => { + const downloadList = await imageGeneration.getMissingModels(); + if (downloadList.length > 0) { + emits( + "showDownloadModelConfirm", + downloadList, + resolve, + reject + ); + } else { + resolve && resolve(); + } + }); +} + function postImageToEnhance() { emits("postImageToEnhance", imageGeneration.imageUrls[imageGeneration.previewIdx]); } diff --git a/service/model_downloader.py b/service/model_downloader.py index 538da864..629358aa 100644 --- a/service/model_downloader.py +++ b/service/model_downloader.py @@ -85,10 +85,12 @@ def __init__(self, hf_token=None) -> None: def is_gated(self, repo_id: str): try: - info = model_info(repo_id) + # strip gguf from repo_id + namespace_and_repo = ("/").join(repo_id.split("/")[0:2]) + info = model_info(namespace_and_repo) return info.gated except Exception as ex: - print(f"Error while trying to determine whether {repo_id} is gated: {ex}") + print(f"Error while trying to determine whether {namespace_and_repo} is gated: {ex}") return False def download(self, repo_id: str, model_type: int, thread_count: int = 4): @@ -169,6 +171,7 @@ def enum_file_list( self, file_list: List, enum_path: str, model_type: int, is_root=True ): list = self.fs.ls(enum_path, detail=True) + print('listing files to get model size', list) if model_type == 1 and enum_path == self.repo_id + "/unet": list = self.enum_sd_unet(list) for item in list: @@ -205,12 +208,15 @@ def enum_file_list( continue self.total_size += size - relative_path = path.relpath(name, self.repo_id) + namespace_and_repo = ("/").join(self.repo_id.split("/")[0:2]) + relative_path = path.relpath(name, namespace_and_repo) subfolder = path.dirname(relative_path).replace("\\", "/") filename = path.basename(relative_path) + print('got size of ', relative_path, subfolder, filename, size) url = hf_hub_url( - repo_id=self.repo_id, subfolder=subfolder, filename=filename + repo_id=namespace_and_repo, subfolder=subfolder, filename=filename ) + print('adding file to download list', relative_path, size, url) file_list.append(HFFileItem(relative_path, size, url)) def enum_sd_unet(self, file_list: List[str | Dict[str, Any]]):