diff --git a/pyproject.toml b/pyproject.toml index 910ebad..eedb2df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,25 +10,25 @@ python = "^3.10" torch = {version = "*", source = "pytorch-gpu-src"} torchvision = {version = "*", source = "pytorch-gpu-src"} torchaudio = {version = "*", source = "pytorch-gpu-src"} -omegaconf = "2.3.0" +omegaconf = ">=2.3.0" diffusers = "0.31.0" compel = "2.0.2" -invisible-watermark = "^0.2.0" -transformers = "^4.41.2" -accelerate = "^0.31.0" -safetensors = "0.4.3" -mediapy = "^1.1.9" +invisible-watermark = "0.2.0" +transformers = ">=4.41.2" +accelerate = ">=0.31.0" +safetensors = ">=0.4.3" +mediapy = ">=1.1.9" ipywidgets = "7.7.1" -controlnet-aux = "^0.0.6" -mediapipe = "0.10.1" -pytorch-lightning = "^2.0.9.post0" -ultralytics = "^8.0.0" -huggingface_hub = "^0.23.1" -peft = "^0.11.1" -torchsde = "^0.2.6" -onnxruntime = "^1.18.0" -insightface = "^0.7.3" -opencv-contrib-python = "^4.8.0.76" +controlnet-aux = "0.0.6" +mediapipe = ">=0.10.1" +pytorch-lightning = ">=2.0.9.post0" +ultralytics = ">=8.0.0" +huggingface_hub = ">=0.23.1" +peft = ">=0.11.1" +torchsde = ">=0.2.6" +onnxruntime = ">=1.18.0" +insightface = ">=0.7.3" +opencv-contrib-python = ">=4.8.0.76" sentencepiece = "*" numpy = "<=1.26.4" lark = "*" diff --git a/stablepy/diffusers_vanilla/model.py b/stablepy/diffusers_vanilla/model.py index cd384f2..b83a576 100644 --- a/stablepy/diffusers_vanilla/model.py +++ b/stablepy/diffusers_vanilla/model.py @@ -50,6 +50,7 @@ check_variant_file, cachebox, release_resources, + validate_and_update_params, ) from .lora_loader import lora_mix_load, load_no_fused_lora from .inpainting_canvas import draw, make_inpaint_condition @@ -233,6 +234,32 @@ def switch_pipe_class( verbose_info=False, ): + if hasattr(self.pipe, "set_pag_applied_layers"): + if not hasattr(self.pipe, "text_encoder_2"): + self.pipe = StableDiffusionPipeline( + vae=self.pipe.vae, + text_encoder=self.pipe.text_encoder, + tokenizer=self.pipe.tokenizer, + 
unet=self.pipe.unet, + scheduler=self.pipe.scheduler, + safety_checker=self.pipe.safety_checker, + feature_extractor=self.pipe.feature_extractor, + image_encoder=self.pipe.image_encoder, + requires_safety_checker=self.pipe.config.requires_safety_checker, + ) + else: + self.pipe = StableDiffusionXLPipeline( + vae=self.pipe.vae, + text_encoder=self.pipe.text_encoder, + text_encoder_2=self.pipe.text_encoder_2, + tokenizer=self.pipe.tokenizer, + tokenizer_2=self.pipe.tokenizer_2, + unet=self.pipe.unet, + scheduler=self.pipe.scheduler, + feature_extractor=self.pipe.feature_extractor, + image_encoder=self.pipe.image_encoder, + ) + tk = "base" model_components = dict( vae=self.pipe.vae, @@ -1284,6 +1311,7 @@ def __call__( image_previews: bool = False, xformers_memory_efficient_attention: bool = False, gui_active: bool = False, + **kwargs, ): """ @@ -2468,6 +2496,7 @@ def _execution_device(self): hires_params_config, hires_pipe, metadata, + kwargs, ) def post_processing( @@ -2653,6 +2682,7 @@ def start_work( hires_params_config, hires_pipe, metadata, + kwargs, ): for i in range(loop_generation): # number seed @@ -2696,6 +2726,8 @@ def start_work( pipe_params_config["output_type"] = "latent" # self.pipe.__class__._execution_device = property(_execution_device) + validate_and_update_params(self.pipe.__class__, kwargs, pipe_params_config) + try: images = self.pipe( **pipe_params_config, @@ -2778,6 +2810,7 @@ def start_stream( hires_params_config, hires_pipe, metadata, + kwargs, ): for i in range(loop_generation): # number seed @@ -2821,6 +2854,8 @@ def start_stream( # self.pipe.__class__._execution_device = property(_execution_device) self.metadata = metadata + validate_and_update_params(self.pipe.__class__, kwargs, pipe_params_config) + try: logger.debug("Start stream") # self.stream_config(5) diff --git a/stablepy/diffusers_vanilla/utils.py b/stablepy/diffusers_vanilla/utils.py index dc3af49..3f2d5d9 100644 --- a/stablepy/diffusers_vanilla/utils.py +++ 
def validate_and_update_params(cls, kwargs, config):
    """
    Copy the entries of `kwargs` that `cls.__call__` accepts into `config`.

    Each key of `kwargs` is checked against the signature of
    `cls.__call__`. Keys naming a parameter that can be passed by keyword
    are written into `config` (overwriting an existing entry with the same
    key); every other key is skipped and reported through `logger.error`.
    `kwargs` itself is never modified.

    The instance receiver (`self`) and the names of variadic parameters
    (`*args`, `**kwargs`) are not treated as valid: forwarding `self=...`
    would make the later call fail with "got multiple values for argument",
    and the variadic names do not denote real keyword parameters.

    Args:
        cls: The class whose `__call__` method parameters are used for
            validation.
        kwargs (dict): The keyword arguments to validate. May be empty or
            None, in which case nothing happens.
        config (dict): The dictionary to update in place with valid
            parameters.

    Returns:
        None. `config` is mutated in place.
    """
    if not kwargs:
        return

    call_fn = cls.__call__
    params = list(inspect.signature(call_fn).parameters.values())

    # Looked up on the class, an ordinary `__call__` is an unbound function
    # whose first positional parameter is the instance receiver. A bound
    # method (e.g. one inherited from a metaclass) has already had its
    # receiver stripped by `inspect.signature`, so leave that case alone.
    receiver_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    if (
        params
        and not inspect.ismethod(call_fn)
        and params[0].kind in receiver_kinds
    ):
        params = params[1:]

    # Only these kinds can actually be supplied as `name=value`.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    valid_params = {p.name for p in params if p.kind in keyword_kinds}

    for name_param, value_param in kwargs.items():
        if name_param in valid_params:
            config[name_param] = value_param
            logger.debug(f"Parameter added: '{name_param}': {value_param}.")
        else:
            logger.error(
                f"The pipeline '{cls.__name__}' had an invalid parameter"
                f" removed: '{name_param}'.")