This would allow the checkpoint_model_type function in utils.py to be something like:
import logging

# load the required modules
from stablepy import huggingface_guess
from picklescan.scanner import scan_file_path as legacy_scan
from safetensors.torch import load_file as safe_load
from torch import load as legacy_load

logger = logging.getLogger(__name__)

# preexisting model type names, keyed by the repository id huggingface_guess reports
REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}


def checkpoint_model_type(checkpoint_path):
    state_dict = None
    model_type = None
    # read the checkpoint
    if checkpoint_path.lower().endswith(".safetensors"):
        state_dict = safe_load(checkpoint_path, device="cpu")
        repo_name = huggingface_guess.guess_repo_name(state_dict)
    else:
        # pickle-based checkpoints can run code when loaded, so scan them first
        scan_result = legacy_scan(checkpoint_path)
        if scan_result.scan_err:
            repo_name = "security_error"
        elif scan_result.infected_files > 0:
            repo_name = "security_blocked"
        else:
            state_dict = legacy_load(checkpoint_path, map_location="cpu")
            repo_name = huggingface_guess.guess_repo_name(state_dict)
    # match the repo_name to the preexisting definitions
    if repo_name in REPO_TO_MODEL_TYPE:
        model_type = REPO_TO_MODEL_TYPE[repo_name]
    elif repo_name == "security_error":
        logger.info("Error reading checkpoint: unable to complete scan for malicious code")
    elif repo_name == "security_blocked":
        logger.info("Error reading checkpoint: potentially malicious code detected")
    else:
        logger.info("Error reading checkpoint: unsupported model type %s", repo_name)
    # unload the checkpoint
    del state_dict
    return model_type
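The branching in checkpoint_model_type can be checked without torch, safetensors, or picklescan installed by pulling its two decisions into plain functions. This is a minimal sketch, assuming only that the scanner result exposes scan_err and infected_files the way picklescan's ScanResult is assumed to above; FakeScanResult, scan_verdict, and model_type_for are hypothetical names used for illustration.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in mirroring only the two ScanResult fields the sketch reads.
@dataclass
class FakeScanResult:
    scan_err: bool = False
    infected_files: int = 0

REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}

def scan_verdict(scan_result) -> Optional[str]:
    """Return None if the file may be loaded, else a security status string."""
    if scan_result.scan_err:
        return "security_error"      # the scan could not complete
    if scan_result.infected_files > 0:
        return "security_blocked"    # suspicious pickle content was reported
    return None

def model_type_for(repo_name: str) -> Optional[str]:
    """Map a guessed repository id to a model type, or None if unsupported."""
    return REPO_TO_MODEL_TYPE.get(repo_name)

if __name__ == "__main__":
    print(scan_verdict(FakeScanResult()))                      # None
    print(scan_verdict(FakeScanResult(infected_files=1)))      # security_blocked
    print(model_type_for("stabilityai/stable-diffusion-2-1"))  # sd2.1
```

Keeping the scan check ahead of any load means a pickle file is never deserialized unless the scan both finished and reported nothing.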
The code used to detect a checkpoint is minimalistic, and the package would be better served by including huggingface_guess as a submodule.
One minor change at model_list.py#L100 would help account for the new SD 1.5 base repository:
huggingface_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"
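If model_list.py keeps one huggingface_repo value per model, utils.py could derive its repository-to-type lookup from those entries instead of repeating the strings, so updating line 100 would be enough. The MODEL_LIST layout and build_repo_index helper below are hypothetical; the real structure of model_list.py may differ.

```python
from typing import Dict

# Hypothetical layout for model_list.py: model type -> metadata.
MODEL_LIST: Dict[str, Dict[str, str]] = {
    "sd1.5": {"huggingface_repo": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
    "sdxl": {"huggingface_repo": "stabilityai/stable-diffusion-xl-base-1.0"},
}

def build_repo_index(model_list: Dict[str, Dict[str, str]]) -> Dict[str, str]:
    """Invert model type -> huggingface_repo into huggingface_repo -> model type."""
    index: Dict[str, str] = {}
    for model_type, entry in model_list.items():
        repo = entry["huggingface_repo"]
        # two model types sharing one repo id would make the lookup ambiguous
        if repo in index:
            raise ValueError(f"duplicate huggingface_repo {repo!r}")
        index[repo] = model_type
    return index

REPO_INDEX = build_repo_index(MODEL_LIST)
```

With this shape, a stale repository id in the list simply fails to match the guessed id, and the checkpoint falls through to the "unsupported model type" branch.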
This assumes that #14 is implemented.