Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
R3gm authored Nov 21, 2023
1 parent 40ac96e commit ab694be
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions stablepy/diffusers_vanilla/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gc
import time
import gc, time
import numpy as np
import PIL.Image
from diffusers import (
Expand All @@ -13,11 +12,10 @@
StableDiffusionXLAdapterPipeline,
T2IAdapter,
StableDiffusionXLPipeline,
AutoPipelineForImage2Image
)
import json
from huggingface_hub import hf_hub_download
import torch
import random
import torch, random, json
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
Expand Down Expand Up @@ -60,19 +58,17 @@
from .high_resolution import process_images_high_resolution
from .style_prompt_config import styles_data, STYLE_NAMES, get_json_content, apply_style
import os
from compel import Compel
from compel import ReturnedEmbeddingsType
from compel import Compel, ReturnedEmbeddingsType
import ipywidgets as widgets, mediapy
from IPython.display import display
from PIL import Image
#from asdff.sd import AdCnPreloadPipe
from typing import Union, Optional, List, Tuple, Dict, Any, Callable
import logging
from diffusers import AutoPipelineForImage2Image
import diffusers
import logging, diffusers, copy, warnings
logging.getLogger("diffusers").setLevel(logging.ERROR)
#logging.getLogger("transformers").setLevel(logging.ERROR)
diffusers.utils.logging.set_verbosity(40)
import copy
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers")

# =====================================
# Utils preprocessor
Expand Down Expand Up @@ -259,16 +255,16 @@ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
"LMS": (LMSDiscreteScheduler, {}),
"LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
"DDIM": (DDIMScheduler, {}),
"DEISMultistep": (DEISMultistepScheduler, {}),
"UniPCMultistep": (UniPCMultistepScheduler, {}),
"DEIS": (DEISMultistepScheduler, {}),
"UniPC": (UniPCMultistepScheduler, {}),
"PNDM" : (PNDMScheduler, {}),

"LCM" : (LCMScheduler, {}),

"DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True}),
"DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"euler_at_final": True}),
"DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
"DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),

"LCM" : (LCMScheduler, {}),
}

scheduler_names = list(SCHEDULER_CONFIG_MAP.keys())
Expand Down Expand Up @@ -433,6 +429,8 @@ def load_pipe(
use_safetensors=True,
add_watermarker=False,
)
self.base_model_id = base_model_id
self.class_name = class_name

# Load VAE after loaded model
if vae_model is None :
Expand All @@ -452,6 +450,7 @@ def load_pipe(
self.pipe.vae.to(self.type_model_precision)
except:
logger.warning(f"VAE: not in {self.type_model_precision}")
self.vae_model = vae_model

# Define base scheduler
self.default_scheduler = copy.deepcopy(self.pipe.scheduler)
Expand Down Expand Up @@ -609,13 +608,6 @@ def load_controlnet_weight(self, task_name: str) -> None:
self.pipe.controlnet = controlnet
#self.task_name = task_name

def get_prompt(self, prompt: str, additional_prompt: str) -> str:
if not prompt:
prompt = additional_prompt
else:
prompt = f"{prompt}, {additional_prompt}"
return prompt

@torch.autocast("cuda")
def run_pipe(
self,
Expand Down Expand Up @@ -1472,7 +1464,7 @@ def __call__(
sampler (str, optional, defaults to "DPM++ 2M"):
The sampler used for the generation process. Available samplers: DPM++ 2M, DPM++ 2M Karras, DPM++ 2M SDE,
DPM++ 2M SDE Karras, DPM++ SDE, DPM++ SDE Karras, DPM2, DPM2 Karras, Euler, Euler a, Heun, LMS, LMS Karras,
DDIM, DEISMultistep, UniPCMultistep, LCM, DPM++ 2M Lu, DPM++ 2M Ef, DPM++ 2M SDE Lu and DPM++ 2M SDE Ef.
DDIM, DEIS, UniPC, DPM2 a, DPM2 a Karras, PNDM, LCM, DPM++ 2M Lu, DPM++ 2M Ef, DPM++ 2M SDE Lu and DPM++ 2M SDE Ef.
syntax_weights (str, optional, defaults to "Classic"):
Specifies the type of syntax weights used during generation. "Classic" is (word:weight), "Compel" is (word)weight
lora_A (str, optional):
Expand Down Expand Up @@ -1619,16 +1611,17 @@ def __call__(
- image_mask
- image_resolution
- strength
- controlnet_conditioning_scale
- control_guidance_start
- control_guidance_end
for SD 1.5:
- controlnet_conditioning_scale
- control_guidance_start
- control_guidance_end
Additional parameters that will be used in img2img:
- image
- image_resolution
- strength
Additional parameters that will be used in ControlNet depending on the task:
Additional parameters that will be used in ControlNet for SD 1.5 depending on the task:
- image
- preprocessor_name
- preprocess_resolution
Expand All @@ -1643,7 +1636,7 @@ def __call__(
- value_threshold
- distance_threshold
Additional parameters that will be used in T2I adapter depending on the task:
Additional parameters that will be used in T2I adapter for SDXL depending on the task:
- image
- preprocess_resolution
- image_resolution
Expand Down Expand Up @@ -2022,7 +2015,7 @@ def __call__(
if self.class_name == "StableDiffusionPipeline":
# Base params pipe sd
pipe_params_config = {
"prompt": None, # prompt, #self.get_prompt(prompt, additional_prompt),
"prompt": None, # prompt,
"negative_prompt": None, # negative_prompt,
"prompt_embeds": prompt_emb,
"negative_prompt_embeds": negative_prompt_emb,
Expand Down

0 comments on commit ab694be

Please sign in to comment.