From 672e8ba40e1687643e79b52f0cd625afa116e62c Mon Sep 17 00:00:00 2001
From: hppRC
Date: Mon, 9 Dec 2024 17:48:51 +0900
Subject: [PATCH] fix: refactor

---
 src/jmteb/embedders/data_parallel_sbert_embedder.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 4cde996..d1b0b11 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -85,13 +85,11 @@ def encode(
             # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
             # Tracking the prompt length allow us to remove the prompt during pooling
             tokenized_prompt = self.sbert.tokenize([prompt])
-            if "input_ids" in tokenized_prompt:
-                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
-            # When `include_prompt` is False in Pooling, prompt_length is unnecessary and should be removed.
-            # This prevents problems arising from DataParallel
-            if self.include_prompt_for_pooling():
-                _ = extra_features.pop("prompt_length")
+            # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel
+            if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
+                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
 
         all_embeddings = []
         length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
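
Note (not part of the patch): a minimal sketch of how the refactored branch behaves. The names `include_prompt_for_pooling`, `extra_features`, and the `- 1` offset come from the diff above; the helper `add_prompt_length` and the toy tensor are hypothetical, for illustration only.

import torch

def add_prompt_length(tokenized_prompt, include_prompt_for_pooling, extra_features):
    # Mirrors the patched branch: prompt_length is recorded only when the
    # Pooling layer excludes the prompt; otherwise the key is never created,
    # so there is no later pop() for DataParallel to trip over.
    if "input_ids" in tokenized_prompt and not include_prompt_for_pooling:
        extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1

# Hypothetical usage with a toy tokenized prompt
features = {}
add_prompt_length({"input_ids": torch.tensor([[0, 42, 2]])},
                  include_prompt_for_pooling=False,
                  extra_features=features)
print(features)  # {'prompt_length': 2}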