feat: accept and validate pipeline params to enable features like guidance_rescale
R3gm committed Dec 1, 2024
1 parent 088f7d6 commit be22637
Showing 3 changed files with 75 additions and 16 deletions.
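
In practice, the change lets callers pass extra diffusers pipeline arguments straight through the model call; anything the active pipeline's `__call__` does not accept is dropped (with an error log) instead of raising. A hypothetical caller-side sketch — the model constructor and the rest of the call signature are not shown in this diff, so `model` and `prompt` here are assumptions:

# Hypothetical usage sketch: `model` is a stablepy diffusers model instance created
# elsewhere (its constructor is not part of this diff). After this commit, unknown
# keyword arguments are collected by **kwargs in __call__ and validated before the
# underlying pipeline is invoked.
result = model(
    prompt="a watercolor landscape",  # assumed to exist in the elided signature
    guidance_rescale=0.7,             # kept only if the active pipeline's __call__ accepts it
    not_a_real_option=True,           # removed from the pipeline call; an error is logged
)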
32 changes: 16 additions & 16 deletions pyproject.toml
@@ -10,25 +10,25 @@ python = "^3.10"
torch = {version = "*", source = "pytorch-gpu-src"}
torchvision = {version = "*", source = "pytorch-gpu-src"}
torchaudio = {version = "*", source = "pytorch-gpu-src"}
-omegaconf = "2.3.0"
+omegaconf = ">=2.3.0"
diffusers = "0.31.0"
compel = "2.0.2"
-invisible-watermark = "^0.2.0"
-transformers = "^4.41.2"
-accelerate = "^0.31.0"
-safetensors = "0.4.3"
-mediapy = "^1.1.9"
+invisible-watermark = "0.2.0"
+transformers = ">=4.41.2"
+accelerate = ">=0.31.0"
+safetensors = ">=0.4.3"
+mediapy = ">=1.1.9"
ipywidgets = "7.7.1"
-controlnet-aux = "^0.0.6"
-mediapipe = "0.10.1"
-pytorch-lightning = "^2.0.9.post0"
-ultralytics = "^8.0.0"
-huggingface_hub = "^0.23.1"
-peft = "^0.11.1"
-torchsde = "^0.2.6"
-onnxruntime = "^1.18.0"
-insightface = "^0.7.3"
-opencv-contrib-python = "^4.8.0.76"
+controlnet-aux = "0.0.6"
+mediapipe = ">=0.10.1"
+pytorch-lightning = ">=2.0.9.post0"
+ultralytics = ">=8.0.0"
+huggingface_hub = ">=0.23.1"
+peft = ">=0.11.1"
+torchsde = ">=0.2.6"
+onnxruntime = ">=1.18.0"
+insightface = ">=0.7.3"
+opencv-contrib-python = ">=4.8.0.76"
sentencepiece = "*"
numpy = "<=1.26.4"
lark = "*"
35 changes: 35 additions & 0 deletions stablepy/diffusers_vanilla/model.py
@@ -50,6 +50,7 @@
check_variant_file,
cachebox,
release_resources,
validate_and_update_params,
)
from .lora_loader import lora_mix_load, load_no_fused_lora
from .inpainting_canvas import draw, make_inpaint_condition
@@ -233,6 +234,32 @@ def switch_pipe_class(
verbose_info=False,
):

if hasattr(self.pipe, "set_pag_applied_layers"):
if not hasattr(self.pipe, "text_encoder_2"):
self.pipe = StableDiffusionPipeline(
vae=self.pipe.vae,
text_encoder=self.pipe.text_encoder,
tokenizer=self.pipe.tokenizer,
unet=self.pipe.unet,
scheduler=self.pipe.scheduler,
safety_checker=self.pipe.safety_checker,
feature_extractor=self.pipe.feature_extractor,
image_encoder=self.pipe.image_encoder,
requires_safety_checker=self.pipe.config.requires_safety_checker,
)
else:
self.pipe = StableDiffusionXLPipeline(
vae=self.pipe.vae,
text_encoder=self.pipe.text_encoder,
text_encoder_2=self.pipe.text_encoder_2,
tokenizer=self.pipe.tokenizer,
tokenizer_2=self.pipe.tokenizer_2,
unet=self.pipe.unet,
scheduler=self.pipe.scheduler,
feature_extractor=self.pipe.feature_extractor,
image_encoder=self.pipe.image_encoder,
)

tk = "base"
model_components = dict(
vae=self.pipe.vae,
@@ -1284,6 +1311,7 @@ def __call__(
image_previews: bool = False,
xformers_memory_efficient_attention: bool = False,
gui_active: bool = False,
**kwargs,
):

"""
@@ -2468,6 +2496,7 @@ def _execution_device(self):
hires_params_config,
hires_pipe,
metadata,
kwargs,
)

def post_processing(
@@ -2653,6 +2682,7 @@ def start_work(
hires_params_config,
hires_pipe,
metadata,
kwargs,
):
for i in range(loop_generation):
# number seed
@@ -2696,6 +2726,8 @@
pipe_params_config["output_type"] = "latent"
# self.pipe.__class__._execution_device = property(_execution_device)

validate_and_update_params(self.pipe.__class__, kwargs, pipe_params_config)

try:
images = self.pipe(
**pipe_params_config,
@@ -2778,6 +2810,7 @@ def start_stream(
hires_params_config,
hires_pipe,
metadata,
kwargs,
):
for i in range(loop_generation):
# number seed
@@ -2821,6 +2854,8 @@
# self.pipe.__class__._execution_device = property(_execution_device)
self.metadata = metadata

validate_and_update_params(self.pipe.__class__, kwargs, pipe_params_config)

try:
logger.debug("Start stream")
# self.stream_config(5)
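
To make the new data flow in model.py concrete: `__call__` now collects unknown keyword arguments into `kwargs`, hands the dict down to `start_work`/`start_stream`, and those methods merge only the names the current pipeline class actually accepts into `pipe_params_config` before calling the pipeline. A self-contained sketch of that pattern with stand-in classes (FakePipeline/FakeModel are illustrations, not stablepy code):

import inspect

def validate_and_update_params(cls, kwargs, config):
    # Same filtering idea as the helper added in utils.py, minus the logging.
    valid_params = inspect.signature(cls.__call__).parameters.keys()
    for name, value in kwargs.items():
        if name in valid_params:
            config[name] = value

class FakePipeline:
    def __call__(self, prompt, guidance_rescale=0.0):
        return f"{prompt} (guidance_rescale={guidance_rescale})"

class FakeModel:
    def __init__(self):
        self.pipe = FakePipeline()

    def __call__(self, prompt, **kwargs):
        # Mirrors the threading in this commit: extra kwargs travel as a plain dict.
        pipe_params_config = {"prompt": prompt}
        validate_and_update_params(self.pipe.__class__, kwargs, pipe_params_config)
        return self.pipe(**pipe_params_config)

print(FakeModel()("a cat", guidance_rescale=0.7, bogus_option=1))
# -> a cat (guidance_rescale=0.7)   (bogus_option never reaches the pipeline)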
24 changes: 24 additions & 0 deletions stablepy/diffusers_vanilla/utils.py
@@ -14,6 +14,7 @@
import hashlib
from collections import OrderedDict
import gc
import inspect


def generate_lora_tags(names_list, scales_list):
@@ -319,6 +320,29 @@ def latents_to_rgb(latents, latent_resize, vae_decoding, pipe):
return resized_image


def validate_and_update_params(cls, kwargs, config):
"""
Validates kwargs against the parameters of a given class's `__call__` method
and updates the provided configuration dictionary.
Args:
cls: The class whose `__call__` method parameters are used for validation.
kwargs (dict): The keyword arguments to validate.
config (dict): The dictionary to update with valid parameters.
"""
if kwargs:
# logger.debug(kwargs)
valid_params = inspect.signature(cls.__call__).parameters.keys()
for name_param, value_param in kwargs.items():
if name_param in valid_params:
config.update({name_param: value_param})
logger.debug(f"Parameter added: '{name_param}': {value_param}.")
else:
logger.error(
f"The pipeline '{cls.__name__}' had an invalid parameter"
f" removed: '{name_param}'.")


def cachebox(max_cache_size=None, hash_func=hashlib.md5):
"""Alternative to lru_cache"""

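
For reference, a minimal sketch of how the new helper behaves when pointed at a real diffusers pipeline class (assuming `stablepy` and `diffusers` are installed and logging is configured so the debug/error messages are visible):

import logging
from diffusers import StableDiffusionPipeline
from stablepy.diffusers_vanilla.utils import validate_and_update_params

logging.basicConfig(level=logging.DEBUG)

pipe_params_config = {"prompt": "an astronaut riding a horse", "num_inference_steps": 28}
extra_kwargs = {"guidance_rescale": 0.7, "made_up_option": 123}  # one valid, one invalid

validate_and_update_params(StableDiffusionPipeline, extra_kwargs, pipe_params_config)

# guidance_rescale is a parameter of StableDiffusionPipeline.__call__, so it is merged in;
# made_up_option is not, so it is dropped and reported via logger.error.
print(pipe_params_config)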
