diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
index e35b8d600..483231985 100644
--- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
+++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
@@ -262,7 +262,7 @@ class IPAdapter(nn.Module):
     def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
                  clip_embeddings_dim=1024, clip_extra_context_tokens=4,
                  is_sdxl=False, is_plus=False, is_full=False,
-                 is_faceid=False, is_instant_id=False):
+                 is_faceid=False, is_instant_id=False, is_instant_style=False):
         super().__init__()
 
         self.clip_embeddings_dim = clip_embeddings_dim
@@ -614,7 +614,7 @@ def INPUT_TYPES(s):
 
     def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original",
                         noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False,
-                        insightface=None, faceid_v2=False, weight_v2=False, instant_id=False):
+                        insightface=None, faceid_v2=False, weight_v2=False, instant_id=False, instant_style=False):
 
         self.dtype = torch.float16 if ldm_patched.modules.model_management.should_use_fp16() else torch.float32
         self.device = ldm_patched.modules.model_management.get_torch_device()
@@ -624,6 +624,7 @@ def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None
         self.is_faceid = self.is_portrait or "0.to_q_lora.down.weight" in ipadapter["ip_adapter"]
         self.is_plus = (self.is_full or "latents" in ipadapter["image_proj"] or "perceiver_resampler.proj_in.weight" in ipadapter["image_proj"])
         self.is_instant_id = instant_id
+        self.is_instant_style = instant_style
 
         if self.is_faceid and not insightface:
             raise Exception('InsightFace must be provided for FaceID models.')
@@ -785,18 +786,22 @@ def modifier(cnet, x_noisy, t, cond, batched_number):
                 patch_kwargs["number"] += 1
             set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
         else:
-            for id in [4,5,7,8]: # id of input_blocks that have cross attention
-                block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
-                for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
+            if not self.is_instant_style:
+                for id in [4,5,7,8]: # id of input_blocks that have cross attention
+                    block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
+                    for index in block_indices:
+                        set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
+                        patch_kwargs["number"] += 1
+                for id in range(6): # id of output_blocks that have cross attention
+                    block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
+                    for index in block_indices:
+                        set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
+                        patch_kwargs["number"] += 1
+                for index in range(10):
+                    set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
                     patch_kwargs["number"] += 1
-            for id in range(6): # id of output_blocks that have cross attention
-                block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
-                for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
-                    patch_kwargs["number"] += 1
-            for index in range(10):
-                set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
+            else: # InstantStyle
+                set_model_patch_replace(work_model, patch_kwargs, ("output", 1, 1)) # target_blocks=["up_blocks.0.attentions.1"]
                 patch_kwargs["number"] += 1
 
         return (work_model, )
diff --git a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py
index 47b720442..b058f4c4e 100644
--- a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py
+++ b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py
@@ -52,6 +52,21 @@ def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider
         )
         return cond
 
+class PreprocessorClipVisionWithForInstantStyle(PreprocessorClipVisionForIPAdapter):
+    def __init__(self, name, url, filename):
+        super().__init__(name, url, filename)
+
+    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
+        cond = dict(
+            clip_vision=self.load_clipvision(),
+            image=numpy_to_pytorch(input_image),
+            weight_type="original",
+            noise=0.0,
+            embeds=None,
+            unfold_batch=False,
+            instant_style=True,
+        )
+        return cond
 
 class PreprocessorInsightFaceForInstantID(Preprocessor):
     def __init__(self, name):
@@ -97,6 +112,12 @@ def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider
     filename='CLIP-ViT-bigG.safetensors'
 ))
 
+add_supported_preprocessor(PreprocessorClipVisionWithForInstantStyle(
+    name='InstantStyle',
+    url='https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors',
+    filename='CLIP-ViT-bigG.safetensors'
+))
+
 add_supported_preprocessor(PreprocessorClipVisionWithInsightFaceForIPAdapter(
     name='InsightFace+CLIP-H (IPAdapter)',
     url='https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors',
@@ -107,7 +128,6 @@ def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider
     name='InsightFace (InstantID)',
 ))
 
-
 class IPAdapterPatcher(ControlModelPatcher):
     @staticmethod
     def try_build_from_state_dict(state_dict, ckpt_path):
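
A note on the InstantStyle branch (reviewer commentary, not part of the patch): instead of installing the IP-Adapter attention patch on every SDXL cross-attention layer, it targets the single transformer block ("output", 1, 1), which the inline comment maps to diffusers' up_blocks.0.attentions.1, the style block named by InstantStyle's target_blocks convention. A minimal sketch, with a hypothetical helper name, that enumerates both key sets using the same transformer depths as the hunk:

```python
# Hypothetical helper mirroring the loops in the @@ -785 hunk above.
def sdxl_patch_keys(instant_style: bool):
    """Enumerate the (block_name, block_id, transformer_index) keys that
    set_model_patch_replace would be called with."""
    if instant_style:
        return [("output", 1, 1)]  # diffusers "up_blocks.0.attentions.1"
    keys = []
    for block_id in [4, 5, 7, 8]:  # input_blocks with cross attention
        depth = 2 if block_id in [4, 5] else 10  # transformer_depth
        keys += [("input", block_id, i) for i in range(depth)]
    for block_id in range(6):  # output_blocks with cross attention
        depth = 2 if block_id in [3, 4, 5] else 10
        keys += [("output", block_id, i) for i in range(depth)]
    keys += [("middle", 0, i) for i in range(10)]  # middle block
    return keys

# 24 input + 36 output + 10 middle = 70 patched layers normally, 1 for InstantStyle.
print(len(sdxl_patch_keys(False)), len(sdxl_patch_keys(True)))  # 70 1
```

Since ("output", 1, 1) is itself a member of the full set, InstantStyle is a strict restriction of the normal patch set rather than a new patch site.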
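On the wiring side, the cond dict built by PreprocessorClipVisionWithForInstantStyle.__call__ is eventually forwarded into apply_ipadapter as keyword arguments, so instant_style is a pure runtime switch, unlike is_faceid or is_plus, which are sniffed from state-dict keys in the @@ -624 hunk. The stand-in below only mimics that hand-off (it is not the extension's API; everything except the dict keys and the instant_style kwarg is an illustrative assumption):

```python
# Stand-in for the hand-off from the preprocessor's cond dict into
# apply_ipadapter; names and return values here are illustrative only.
def apply_ipadapter_stub(weight_type="original", noise=None, embeds=None,
                         unfold_batch=False, instant_id=False,
                         instant_style=False, **_ignored):
    # Mirrors: self.is_instant_id = instant_id / self.is_instant_style = instant_style
    if instant_style:
        return [("output", 1, 1)]  # patch only up_blocks.0.attentions.1
    return "patch every cross-attention block"

# What PreprocessorClipVisionWithForInstantStyle.__call__ returns, with the
# heavyweight clip_vision/image entries stubbed to None for the sketch:
cond = dict(clip_vision=None, image=None, weight_type="original",
            noise=0.0, embeds=None, unfold_batch=False, instant_style=True)

print(apply_ipadapter_stub(**cond))  # [('output', 1, 1)]
```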