From 79cbc906061dc93f8e40434c2b875330edd9695a Mon Sep 17 00:00:00 2001 From: BioMike BioMikeUkr <77461417+BioMikeUkr@users.noreply.github.com> Date: Tue, 28 Jan 2025 22:17:25 +0200 Subject: [PATCH 1/2] fixed normalize_embeddings when Normalize presented in modules --- sentence_transformers/SentenceTransformer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index c497dadbc..8758df3ca 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -685,6 +685,8 @@ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]: return super().forward(input) for module_name, module in self.named_children(): + if not kwargs.get("normalize_embeddings", False) and isinstance(module, Normalize): + return input module_kwarg_keys = self.module_kwargs.get(module_name, []) module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys} input = module(input, **module_kwargs) From 23b6c0f4f4e4b5ff7f252098341ae56095f6b4e1 Mon Sep 17 00:00:00 2001 From: werent4 Date: Tue, 28 Jan 2025 22:25:31 +0200 Subject: [PATCH 2/2] Add missing normalize_embeddings to kwargs --- sentence_transformers/SentenceTransformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 8758df3ca..17b6bd38b 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -526,6 +526,7 @@ def encode( print(embeddings.shape) # (3, 768) """ + kwargs["normalize_embeddings"] = normalize_embeddings if self.device.type == "hpu" and not self.is_hpu_graph_enabled: import habana_frameworks.torch as ht