Add tokenizer_kwargs in SentenceTransformer init
lsz05 committed Jun 18, 2024
1 parent 32fed50 commit 6a93870
Showing 1 changed file with 2 additions and 10 deletions.
src/jmteb/embedders/sbert_embedder.py: 12 changes (2 additions, 10 deletions)
@@ -16,29 +16,21 @@ def __init__(
         device: str | None = None,
         normalize_embeddings: bool = False,
         max_seq_length: int | None = None,
-        tokenizer_padding_side: str | None = None,
         add_eos: bool = False,
+        tokenizer_kwargs: dict | None = None,
     ) -> None:
-        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True)
+        self.model = SentenceTransformer(model_name_or_path, trust_remote_code=True, tokenizer_kwargs=tokenizer_kwargs)
         if max_seq_length:
             self.model.max_seq_length = max_seq_length
-        if tokenizer_padding_side:
-            try:
-                self.model.tokenizer.padding_side = "right"
-            except AttributeError:
-                pass
 
         self.batch_size = batch_size
         self.device = device
         self.normalize_embeddings = normalize_embeddings
         self.max_seq_length = max_seq_length
-        self.tokenizer_padding_side = tokenizer_padding_side
         self.add_eos = add_eos
 
         if self.max_seq_length:
             self.model.max_seq_length = self.max_seq_length
-        if self.tokenizer_padding_side:
-            setattr(self.model.tokenizer, "padding_side", self.tokenizer_padding_side)
 
     def encode(self, text: str | list[str], prompt: str | None = None) -> np.ndarray:
         if self.add_eos:
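Usage note (not part of the commit): a minimal sketch of how a caller might pass the new argument after this change. It assumes the embedder class defined in src/jmteb/embedders/sbert_embedder.py is named SentenceBertEmbedder, that the import path and model name below are only illustrative, and that SentenceTransformer forwards tokenizer_kwargs to the underlying tokenizer's from_pretrained call; the tokenizer_kwargs entry replaces the removed tokenizer_padding_side argument.

# Sketch only: class name, import path, and model name are assumptions for illustration.
from jmteb.embedders.sbert_embedder import SentenceBertEmbedder

embedder = SentenceBertEmbedder(
    "intfloat/multilingual-e5-small",           # any SentenceTransformer checkpoint
    batch_size=32,
    max_seq_length=512,
    tokenizer_kwargs={"padding_side": "left"},  # forwarded to the tokenizer; replaces tokenizer_padding_side
)
vectors = embedder.encode(["これはテスト文です。", "This is a test sentence."])
print(vectors.shape)  # encode returns an np.ndarray per the signature above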
