diff --git a/src/jmteb/embedders/data_parallel_sbert_embedder.py b/src/jmteb/embedders/data_parallel_sbert_embedder.py
index 4cde996..d1b0b11 100644
--- a/src/jmteb/embedders/data_parallel_sbert_embedder.py
+++ b/src/jmteb/embedders/data_parallel_sbert_embedder.py
@@ -85,13 +85,11 @@ def encode(
             # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
             # Tracking the prompt length allow us to remove the prompt during pooling
             tokenized_prompt = self.sbert.tokenize([prompt])
-            if "input_ids" in tokenized_prompt:
-                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
-            # When `include_prompt` is False in Pooling, prompt_length is unnecessary and should be removed.
-            # This prevents problems arising from DataParallel
-            if self.include_prompt_for_pooling():
-                _ = extra_features.pop("prompt_length")
+            # When `include_prompt` is True in Pooling, prompt_length is unnecessary and should be removed.
+            # This prevents problems arising from DataParallel
+            if "input_ids" in tokenized_prompt and not self.include_prompt_for_pooling():
+                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
 
         all_embeddings = []
         length_sorted_idx = np.argsort([-self.sbert._text_length(sen) for sen in sentences])
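
For context on the flag being checked in the new condition, below is a minimal sketch of what an `include_prompt_for_pooling()` helper could look like. This is an illustrative assumption, not the repository's actual implementation; it relies only on the public `Pooling.include_prompt` attribute available in sentence-transformers 2.4+ and on `SentenceTransformer` being an `nn.Sequential` of modules.

```python
# Hypothetical sketch of an include_prompt_for_pooling() check (not the repo's code).
# Assumes sentence-transformers >= 2.4, where Pooling exposes `include_prompt`.
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling


def include_prompt_for_pooling(model: SentenceTransformer) -> bool:
    """Return True if every Pooling module keeps prompt tokens during pooling.

    When this is True, nothing downstream consumes `prompt_length`, so the
    patched encode() above never puts it into `extra_features` at all.
    """
    for module in model.modules():
        if isinstance(module, Pooling):
            # Older sentence-transformers versions have no `include_prompt`
            # attribute; they always pool over the prompt, so default to True.
            if not getattr(module, "include_prompt", True):
                return False
    return True


if __name__ == "__main__":
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    print(include_prompt_for_pooling(model))  # True for a standard pooling config
```

The refactor in the diff replaces the earlier add-then-pop pattern with a single guarded assignment, so `extra_features` only ever contains keys that pooling will actually consume, which is what the DataParallel comment is guarding against.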