From 6a93870cbc7f51914d0311f3aafb30e8272db381 Mon Sep 17 00:00:00 2001
From: "shengzhe.li"
Date: Tue, 18 Jun 2024 12:58:04 +0900
Subject: [PATCH] Add tokenizer_kwargs in SentenceTransformer init

---
 src/jmteb/embedders/sbert_embedder.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py
index 6fbc48e..dbed505 100644
--- a/src/jmteb/embedders/sbert_embedder.py
+++ b/src/jmteb/embedders/sbert_embedder.py
@@ -16,29 +16,21 @@ def __init__(
         device: str | None = None,
         normalize_embeddings: bool = False,
         max_seq_length: int | None = None,
-        tokenizer_padding_side: str | None = None,
         add_eos: bool = False,
+        tokenizer_kwargs: dict | None = None,
     ) -> None:
-        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
+        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs)
         if max_seq_length:
             self.model.max_seq_length = max_seq_length
-        if tokenizer_padding_side:
-            try:
-                self.model.tokenizer.padding_side = "right"
-            except AttributeError:
-                pass
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
         self.max_seq_length = max_seq_length
-        self.tokenizer_padding_side = tokenizer_padding_side
         self.add_eos = add_eos

         if self.max_seq_length:
             self.model.max_seq_length = self.max_seq_length
-        if self.tokenizer_padding_side:
-            setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)

     def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
         if self.add_eos:
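
A minimal usage sketch of the new parameter, assuming sentence-transformers >= 2.3 (where SentenceTransformer accepts tokenizer_kwargs and forwards it to AutoTokenizer.from_pretrained); the model name below is illustrative:

    from sentence_transformers import SentenceTransformer

    # tokenizer_kwargs is passed through to AutoTokenizer.from_pretrained,
    # so any tokenizer option (here padding_side, previously handled by the
    # dedicated tokenizer_padding_side argument this patch removes) can be
    # set at construction time.
    model = SentenceTransformer(
        "intfloat/multilingual-e5-small",  # illustrative model name
        trust_remote_code=True,
        tokenizer_kwargs={"padding_side": "left"},
    )
    embeddings = model.encode(["こんにちは、世界"], normalize_embeddings=True)
    print(embeddings.shape)  # (1, embedding_dim)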