[Refactor] Replace custom detection code with huggingface_guess #15

Open
iwr-redmond opened this issue Dec 6, 2024 · 0 comments
Labels
enhancement New feature or request


iwr-redmond commented Dec 6, 2024

The current code for detecting a checkpoint's model type is minimal; the package would be better served by including huggingface_guess as a submodule.

A minor change at model_list.py#L100 would also account for the new SD 1.5 base repository:

```python
huggingface_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"
```

This would allow the checkpoint_model_type function in utils.py to be something like:

```python
import logging

# assumes huggingface_guess is vendored into stablepy as a submodule, as proposed above
from stablepy import huggingface_guess
from picklescan.scanner import scan_file_path as legacy_scan
from safetensors.torch import load_file as safe_load
from torch import load as legacy_load

logger = logging.getLogger(__name__)

# match the repo names reported by huggingface_guess to the preexisting definitions
REPO_TO_MODEL_TYPE = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5": "sd1.5",
    "stabilityai/stable-diffusion-2-1": "sd2.1",
    "stabilityai/stable-diffusion-xl-base-1.0": "sdxl",
    "stabilityai/stable-diffusion-xl-refiner-1.0": "refiner",
    "black-forest-labs/FLUX.1-dev": "flux-dev",
    "black-forest-labs/FLUX.1-schnell": "flux-schnell",
}


def checkpoint_model_type(checkpoint_path):
    state_dict = None
    repo_name = None
    try:
        # read the checkpoint
        if checkpoint_path.lower().endswith(".safetensors"):
            # safetensors files hold tensors only, so no pickle scan is needed
            state_dict = safe_load(checkpoint_path, device="cpu")
        else:
            # scan_file_path returns a ScanResult object, not the CLI exit code
            scan_result = legacy_scan(checkpoint_path)
            if scan_result.scan_err:
                logger.info("Error reading checkpoint: unable to complete scan for malicious code")
                return None
            if scan_result.infected_files > 0:
                logger.info("Error reading checkpoint: potentially malicious code detected")
                return None
            # torch.load takes map_location, not device
            state_dict = legacy_load(checkpoint_path, map_location="cpu")
            # many .ckpt files nest the weights under a "state_dict" key
            if isinstance(state_dict, dict) and "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
        repo_name = huggingface_guess.guess_repo_name(state_dict)
    except Exception as e:
        logger.debug(str(e))
        logger.info("Error reading checkpoint: unable to load or identify the model")
        return None
    finally:
        # unload the checkpoint
        del state_dict

    # compare strings by value (a dict lookup), never with `is`
    model_type = REPO_TO_MODEL_TYPE.get(repo_name)
    if model_type is None:
        logger.info("Error reading checkpoint: unsupported model type %s", repo_name)
    return model_type
```
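`picklescan.scanner.scan_file_path` returns a `ScanResult` object, so comparing its result to `0` or `2` (the picklescan CLI's exit codes) never matches. If the exit-code convention is easier to read, a small adapter could translate the result. This is a minimal sketch: `scan_exit_code` is a hypothetical helper, and it assumes `ScanResult` exposes `scan_err` and `infected_files`, mirroring the CLI's codes (0 = clean, 1 = malicious code found, 2 = scan failed).

```python
from types import SimpleNamespace


def scan_exit_code(scan_result):
    """Translate a picklescan ScanResult into the CLI exit-code convention.

    Hypothetical helper: 0 = clean, 1 = potentially malicious code found,
    2 = scan failed. Only the scan_err and infected_files attributes are read.
    """
    if scan_result.scan_err:
        return 2
    if scan_result.infected_files > 0:
        return 1
    return 0


# Stand-ins carrying the same attributes as picklescan.scanner.ScanResult.
print(scan_exit_code(SimpleNamespace(scan_err=False, infected_files=0)))  # 0
print(scan_exit_code(SimpleNamespace(scan_err=False, infected_files=3)))  # 1
print(scan_exit_code(SimpleNamespace(scan_err=True, infected_files=0)))  # 2
```

With such an adapter, `scan_exit_code(legacy_scan(checkpoint_path)) == 0` expresses the intent of the original `scan_result is 0` check, using `==` rather than an identity comparison.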

This assumes that #14 is implemented.
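If patching the vendored model_list.py is not an option, a similar effect could be had on the stablepy side with a small alias table applied to the repo name before the model-type lookup. This is a minimal sketch that assumes the unpatched entry still reports the older `runwayml/stable-diffusion-v1-5` id; `LEGACY_REPO_ALIASES` and `normalize_repo_name` are hypothetical names.

```python
# Hypothetical alias table: repo ids an unpatched model_list.py may still report
# (assumed), mapped to the id stablepy's model-type lookup expects.
LEGACY_REPO_ALIASES = {
    "runwayml/stable-diffusion-v1-5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
}


def normalize_repo_name(repo_name):
    """Return the current repo id for a possibly outdated one; pass others through."""
    return LEGACY_REPO_ALIASES.get(repo_name, repo_name)


print(normalize_repo_name("runwayml/stable-diffusion-v1-5"))
# stable-diffusion-v1-5/stable-diffusion-v1-5
```

Unknown names pass through unchanged, so the lookup still reports them as unsupported.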

R3gm added the enhancement (New feature or request) label on Dec 7, 2024