Skip to content

Commit

Permalink
wip gguf model downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
mschuettlerTNG committed Nov 22, 2024
1 parent f18c439 commit cf26cbd
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
17 changes: 17 additions & 0 deletions WebUI/src/views/Create.vue
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,27 @@ const emits = defineEmits<{
}>();
async function generateImage() {
await ensureModelsAreAvailable();
reset();
await imageGeneration.generate();
}
async function ensureModelsAreAvailable() {
return new Promise<void>(async (resolve, reject) => {
const downloadList = await imageGeneration.getMissingModels();
if (downloadList.length > 0) {
emits(
"showDownloadModelConfirm",
downloadList,
resolve,
reject
);
} else {
resolve && resolve();
}
});
}
function postImageToEnhance() {
emits("postImageToEnhance", imageGeneration.imageUrls[imageGeneration.previewIdx]);
}
Expand Down
14 changes: 10 additions & 4 deletions service/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ def __init__(self, hf_token=None) -> None:

def is_gated(self, repo_id: str):
try:
info = model_info(repo_id)
# strip gguf from repo_id
namespace_and_repo = ("/").join(repo_id.split("/")[0:2])
info = model_info(namespace_and_repo)
return info.gated
except Exception as ex:
print(f"Error while trying to determine whether {repo_id} is gated: {ex}")
print(f"Error while trying to determine whether {namespace_and_repo} is gated: {ex}")
return False

def download(self, repo_id: str, model_type: int, thread_count: int = 4):
Expand Down Expand Up @@ -169,6 +171,7 @@ def enum_file_list(
self, file_list: List, enum_path: str, model_type: int, is_root=True
):
list = self.fs.ls(enum_path, detail=True)
print('listing files to get model size', list)
if model_type == 1 and enum_path == self.repo_id + "/unet":
list = self.enum_sd_unet(list)
for item in list:
Expand Down Expand Up @@ -205,12 +208,15 @@ def enum_file_list(
continue

self.total_size += size
relative_path = path.relpath(name, self.repo_id)
namespace_and_repo = ("/").join(self.repo_id.split("/")[0:2])
relative_path = path.relpath(name, namespace_and_repo)
subfolder = path.dirname(relative_path).replace("\\", "/")
filename = path.basename(relative_path)
print('got size of ', relative_path, subfolder, filename, size)
url = hf_hub_url(
repo_id=self.repo_id, subfolder=subfolder, filename=filename
repo_id=namespace_and_repo, subfolder=subfolder, filename=filename
)
print('adding file to download list', relative_path, size, url)
file_list.append(HFFileItem(relative_path, size, url))

def enum_sd_unet(self, file_list: List[str | Dict[str, Any]]):
Expand Down

0 comments on commit cf26cbd

Please sign in to comment.