diff --git a/src/jmteb/embedders/sbert_embedder.py b/src/jmteb/embedders/sbert_embedder.py index 6fbc48e..dbed505 100644 --- a/src/jmteb/embedders/sbert_embedder.py +++ b/src/jmteb/embedders/sbert_embedder.py @@ -16,29 +16,21 @@ def __init__( device: str | None = None, normalize_embeddings: bool = False, max_seq_length: int | None = None, - tokenizer_padding_side: str | None = None, add_eos: bool = False, + tokenizer_kwargs: dict | None = None, ) -> None: - self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True) + self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs) if max_seq_length: self.model.max_seq_length = max_seq_length - if tokenizer_padding_side: - try: - self.model.tokenizer.padding_side = "right" - except AttributeError: - pass self.batch_size = batch_size self.device = device self.normalize_embeddings = normalize_embeddings self.max_seq_length = max_seq_length - self.tokenizer_padding_side = tokenizer_padding_side self.add_eos = add_eos if self.max_seq_length: self.model.max_seq_length = self.max_seq_length - if self.tokenizer_padding_side: - setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side) def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray: if self.add_eos: