From d94a1f9483f54f8b324f666c3211959af0dde095 Mon Sep 17 00:00:00 2001 From: Teriks Date: Tue, 10 Sep 2024 23:56:48 -0500 Subject: [PATCH] stable cascade support, new ReturnedEmbeddingsType --- src/compel/embeddings_provider.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/compel/embeddings_provider.py b/src/compel/embeddings_provider.py index e5f47b9..e76bc7a 100644 --- a/src/compel/embeddings_provider.py +++ b/src/compel/embeddings_provider.py @@ -22,6 +22,7 @@ class ReturnedEmbeddingsType(Enum): LAST_HIDDEN_STATES_NORMALIZED = 0 # SD1/2 regular PENULTIMATE_HIDDEN_STATES_NORMALIZED = 1 # SD1.5 with "clip skip" PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED = 2 # SDXL + STABLE_CASCADE = 3 # Stable Cascade class EmbeddingsProvider: @@ -234,7 +235,7 @@ def get_token_ids(self, texts: List[str], include_start_and_end_markers: bool = return result def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch.Tensor]=None, device: Optional[str]=None) -> Optional[torch.Tensor]: - + device = device or self.device token_ids = self.get_token_ids(texts, padding="max_length", truncation_override=True) @@ -243,7 +244,10 @@ def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True) pooled = text_encoder_output.text_embeds - return pooled + if self.returned_embeddings_type is ReturnedEmbeddingsType.STABLE_CASCADE: + return pooled.unsqueeze(1) + else: + return pooled def get_token_ids_and_expand_weights(self, fragments: List[str], weights: List[float], device: str @@ -386,7 +390,8 @@ def build_weighted_embedding_tensor(self, def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]=None) -> torch.Tensor: needs_hidden_states = (self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED or - self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED) + self.returned_embeddings_type == ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED or + self.returned_embeddings_type == ReturnedEmbeddingsType.STABLE_CASCADE) text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, @@ -400,6 +405,9 @@ def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor, elif self.returned_embeddings_type is ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED: # already normalized return text_encoder_output.last_hidden_state + elif self.returned_embeddings_type is ReturnedEmbeddingsType.STABLE_CASCADE: + # last_hidden_state attribute does not work, non-intuitive + return text_encoder_output.hidden_states[-1] assert False, f"unrecognized ReturnEmbeddingsType: {self.returned_embeddings_type}"