diff --git a/stablepy/diffusers_vanilla/model.py b/stablepy/diffusers_vanilla/model.py index 2cd1253..f8746dd 100644 --- a/stablepy/diffusers_vanilla/model.py +++ b/stablepy/diffusers_vanilla/model.py @@ -1,5 +1,4 @@ -import gc -import time +import gc, time import numpy as np import PIL.Image from diffusers import ( @@ -13,11 +12,10 @@ StableDiffusionXLAdapterPipeline, T2IAdapter, StableDiffusionXLPipeline, + AutoPipelineForImage2Image ) -import json from huggingface_hub import hf_hub_download -import torch -import random +import torch, random, json from controlnet_aux import ( CannyDetector, ContentShuffleDetector, @@ -60,19 +58,17 @@ from .high_resolution import process_images_high_resolution from .style_prompt_config import styles_data, STYLE_NAMES, get_json_content, apply_style import os -from compel import Compel -from compel import ReturnedEmbeddingsType +from compel import Compel, ReturnedEmbeddingsType import ipywidgets as widgets, mediapy from IPython.display import display from PIL import Image -#from asdff.sd import AdCnPreloadPipe from typing import Union, Optional, List, Tuple, Dict, Any, Callable -import logging -from diffusers import AutoPipelineForImage2Image -import diffusers +import logging, diffusers, copy, warnings logging.getLogger("diffusers").setLevel(logging.ERROR) +#logging.getLogger("transformers").setLevel(logging.ERROR) diffusers.utils.logging.set_verbosity(40) -import copy +warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers") +warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers") # ===================================== # Utils preprocessor @@ -259,16 +255,16 @@ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image: "LMS": (LMSDiscreteScheduler, {}), "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}), "DDIM": (DDIMScheduler, {}), - "DEISMultistep": (DEISMultistepScheduler, {}), - "UniPCMultistep": (UniPCMultistepScheduler, {}), + "DEIS": (DEISMultistepScheduler, {}), + "UniPC": (UniPCMultistepScheduler, {}), "PNDM" : (PNDMScheduler, {}), - "LCM" : (LCMScheduler, {}), - "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True}), "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"euler_at_final": True}), "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}), "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}), + + "LCM" : (LCMScheduler, {}), } scheduler_names = list(SCHEDULER_CONFIG_MAP.keys()) @@ -433,6 +429,8 @@ def load_pipe( use_safetensors=True, add_watermarker=False, ) + self.base_model_id = base_model_id + self.class_name = class_name # Load VAE after loaded model if vae_model is None : @@ -452,6 +450,7 @@ def load_pipe( self.pipe.vae.to(self.type_model_precision) except: logger.warning(f"VAE: not in {self.type_model_precision}") + self.vae_model = vae_model # Define base scheduler self.default_scheduler = copy.deepcopy(self.pipe.scheduler) @@ -609,13 +608,6 @@ def load_controlnet_weight(self, task_name: str) -> None: self.pipe.controlnet = controlnet #self.task_name = task_name - def get_prompt(self, prompt: str, additional_prompt: str) -> str: - if not prompt: - prompt = additional_prompt - else: - prompt = f"{prompt}, {additional_prompt}" - return prompt - @torch.autocast("cuda") def run_pipe( self, @@ -1472,7 +1464,7 @@ def __call__( sampler (str, optional, defaults to "DPM++ 2M"): The sampler used for the generation process. Available samplers: DPM++ 2M, DPM++ 2M Karras, DPM++ 2M SDE, DPM++ 2M SDE Karras, DPM++ SDE, DPM++ SDE Karras, DPM2, DPM2 Karras, Euler, Euler a, Heun, LMS, LMS Karras, - DDIM, DEISMultistep, UniPCMultistep, LCM, DPM++ 2M Lu, DPM++ 2M Ef, DPM++ 2M SDE Lu and DPM++ 2M SDE Ef. + DDIM, DEIS, UniPC, DPM2 a, DPM2 a Karras, PNDM, LCM, DPM++ 2M Lu, DPM++ 2M Ef, DPM++ 2M SDE Lu and DPM++ 2M SDE Ef. syntax_weights (str, optional, defaults to "Classic"): Specifies the type of syntax weights used during generation. "Classic" is (word:weight), "Compel" is (word)weight lora_A (str, optional): @@ -1619,16 +1611,17 @@ def __call__( - image_mask - image_resolution - strength - - controlnet_conditioning_scale - - control_guidance_start - - control_guidance_end + for SD 1.5: + - controlnet_conditioning_scale + - control_guidance_start + - control_guidance_end Additional parameters that will be used in img2img: - image - image_resolution - strength - Additional parameters that will be used in ControlNet depending on the task: + Additional parameters that will be used in ControlNet for SD 1.5 depending on the task: - image - preprocessor_name - preprocess_resolution @@ -1643,7 +1636,7 @@ def __call__( - value_threshold - distance_threshold - Additional parameters that will be used in T2I adapter depending on the task: + Additional parameters that will be used in T2I adapter for SDXL depending on the task: - image - preprocess_resolution - image_resolution @@ -2022,7 +2015,7 @@ def __call__( if self.class_name == "StableDiffusionPipeline": # Base params pipe sd pipe_params_config = { - "prompt": None, # prompt, #self.get_prompt(prompt, additional_prompt), + "prompt": None, # prompt, "negative_prompt": None, # negative_prompt, "prompt_embeds": prompt_emb, "negative_prompt_embeds": negative_prompt_emb,