Commit
feat: trailing variant
R3gm committed Jun 14, 2024
1 parent 430c660 commit 343a220
Showing 2 changed files with 29 additions and 11 deletions.
20 changes: 14 additions & 6 deletions stablepy/diffusers_vanilla/constants.py
@@ -67,12 +67,12 @@

 FLASH_LORA = {
     "StableDiffusionPipeline": {
-        "LCM": "latent-consistency/lcm-lora-sdv1-5",
-        "TCD": "h1t/TCD-SD15-LoRA",
+        "LCM Auto-Loader": "latent-consistency/lcm-lora-sdv1-5",
+        "TCD Auto-Loader": "h1t/TCD-SD15-LoRA",
     },
     "StableDiffusionXLPipeline": {
-        "LCM": "latent-consistency/lcm-lora-sdxl",
-        "TCD": "h1t/TCD-SDXL-LoRA",
+        "LCM Auto-Loader": "latent-consistency/lcm-lora-sdxl",
+        "TCD Auto-Loader": "h1t/TCD-SDXL-LoRA",
     },
 }

@@ -107,11 +107,14 @@
     "KDPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
     "Euler": (EulerDiscreteScheduler, {}),
     "Euler a": (EulerAncestralDiscreteScheduler, {}),
+    "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
+    "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
     "Heun": (HeunDiscreteScheduler, {}),
     "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
     "LMS": (LMSDiscreteScheduler, {}),
     "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
     "DDIM": (DDIMScheduler, {}),
+    "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
     "DEIS": (DEISMultistepScheduler, {}),
     "UniPC": (UniPCMultistepScheduler, {}),
     "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
@@ -127,11 +130,16 @@
     "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
     "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),

-    "TCD": (TCDScheduler, {}),
     "LCM": (LCMScheduler, {}),
+    "TCD": (TCDScheduler, {}),
+    "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
+    "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
+    "LCM Auto-Loader": (LCMScheduler, {}),
+    "TCD Auto-Loader": (TCDScheduler, {}),
 }

 scheduler_names = list(SCHEDULER_CONFIG_MAP.keys())
+FLASH_AUTO_LOAD_SAMPLER = scheduler_names[-2:]

 IP_ADAPTER_MODELS = {
     "StableDiffusionPipeline": {
@@ -718,7 +726,7 @@ def name_list_ip_adapters(model_key):
         "negative_prompt": "Meticulous craftsmanship, figural realism, premeditated approach"
     },
     {
-        "name": "",
+        "name": "Pure Typography",
         "prompt": "Lettrist artwork {prompt} . Avant-garde letters freed into pure form, visual rhythm and texture of typography as sole communication",
         "negative_prompt": "Readability, coherence, concrete meaning beyond visual experience"
     },
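The new "trailing" entries override the scheduler's timestep_spacing, which shifts the sampled timesteps toward the end of the noise schedule (useful for few-step sampling). A minimal sketch of what one of these (class, kwargs) pairs amounts to when applied to a diffusers pipeline; the model ID and dtype below are placeholders, not part of this commit:

    # Sketch: each SCHEDULER_CONFIG_MAP value unpacks into a scheduler
    # class plus extra config kwargs applied over the pipe's config.
    import torch
    from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

    scheduler_class, extra_kwargs = (
        EulerAncestralDiscreteScheduler,
        {"timestep_spacing": "trailing"},
    )  # the new "Euler a trailing" entry

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder model ID
        torch_dtype=torch.float16,
    )
    pipe.scheduler = scheduler_class.from_config(pipe.scheduler.config, **extra_kwargs)

Note that FLASH_AUTO_LOAD_SAMPLER = scheduler_names[-2:] relies on dict insertion order, so it resolves to the two "Auto-Loader" entries that now close SCHEDULER_CONFIG_MAP.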
20 changes: 15 additions & 5 deletions stablepy/diffusers_vanilla/model.py
@@ -55,6 +55,7 @@
     REPO_IMAGE_ENCODER,
     PROMPT_WEIGHT_OPTIONS,
     OLD_PROMPT_WEIGHT_OPTIONS,
+    FLASH_AUTO_LOAD_SAMPLER,
 )
 from .multi_emphasis_prompt import long_prompts_with_weighting
 from diffusers.utils import load_image
@@ -283,6 +284,7 @@ def __init__(
         task_name: str = "txt2img",
         vae_model=None,
         type_model_precision=torch.float16,
+        retain_task_model_in_cache=True,
     ):
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.base_model_id = ""
@@ -293,7 +295,11 @@
         )  # For SD 1.5

         self.load_pipe(
-            base_model_id, task_name, vae_model, type_model_precision
+            base_model_id,
+            task_name,
+            vae_model,
+            type_model_precision,
+            retain_task_model_in_cache,
         )
         self.preprocessor = Preprocessor()

@@ -311,7 +317,7 @@ def load_pipe(
         vae_model=None,
         type_model_precision=torch.float16,
         reload=False,
-        retain_model_in_memory=True,
+        retain_task_model_in_cache=True,
     ) -> DiffusionPipeline:
         if (
             base_model_id == self.base_model_id
@@ -633,7 +639,7 @@ def load_pipe(
             self.pipe.enable_vae_tiling()
         self.pipe.watermark = None

-        if retain_model_in_memory is True and task_name not in self.model_memory:
+        if retain_task_model_in_cache is True and task_name not in self.model_memory:
             self.model_memory[task_name] = self.pipe

         return
@@ -1139,6 +1145,10 @@ def load_beta_styles(self):
         self.STYLE_NAMES = STYLE_NAMES
         self.style_json_file = ""

+        logger.info(
+            f"Beta styles loaded with {len(self.STYLE_NAMES)} styles"
+        )
+
     def set_ip_adapter_multimode_scale(self, ip_scales, ip_adapter_mode):
         mode_scales = []
         for scale, mode in zip(ip_scales, ip_adapter_mode):
@@ -1750,13 +1760,13 @@ def __call__(
             lora_scale_E,
         ]

-        if sampler in ["TCD", "LCM"] and self.flash_config is None:
+        if sampler in FLASH_AUTO_LOAD_SAMPLER and self.flash_config is None:
             # First load
             flash_task_lora = FLASH_LORA[self.class_name][sampler]
             self.process_lora(flash_task_lora, 1.0)
             self.flash_config = flash_task_lora
             logger.info(sampler)
-        elif sampler not in ["TCD", "LCM"] and self.flash_config is not None:
+        elif sampler not in FLASH_AUTO_LOAD_SAMPLER and self.flash_config is not None:
             # Unload
             self.process_lora(self.flash_config, 1.0, unload=True)
             self.flash_config = None
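With the sampler names sourced from FLASH_AUTO_LOAD_SAMPLER, the flash LoRA now attaches only for the renamed "LCM Auto-Loader" and "TCD Auto-Loader" samplers; plain "LCM"/"TCD" and the new trailing variants run without it. A hedged usage sketch of the new surface (the model path is a placeholder, and the call signature is assumed from stablepy's usual Model_Diffusers entry point rather than confirmed by this diff):

    from stablepy import Model_Diffusers

    model = Model_Diffusers(
        base_model_id="./models/model.safetensors",  # placeholder path
        task_name="txt2img",
        retain_task_model_in_cache=False,  # new flag: skip the per-task pipe cache
    )

    # "LCM Auto-Loader" triggers the automatic flash-LoRA load on first use;
    # switching to any non-Auto-Loader sampler later unloads it.
    result = model(
        prompt="a watercolor fox",
        sampler="LCM Auto-Loader",
        num_steps=8,
        guidance_scale=1.0,
    )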
