TransformerEmbedder.py

from typing import List, Optional

import warnings

# Suppress FutureWarnings before the heavier imports below can emit them
# (e.g. deprecation notices from transformers / sentence-transformers).
warnings.simplefilter(action='ignore', category=FutureWarning)

from numpy import ndarray
from sentence_transformers import SentenceTransformer


class TransformerEmbedder:
    """
    A class to generate embeddings for documents and queries using a pre-trained transformer model.

    The `TransformerEmbedder` class uses a `SentenceTransformer` model to encode texts
    (documents or queries) into vector embeddings, which is useful for natural language
    processing tasks such as semantic search, clustering, and classification.

    Attributes:
    -----------
    model : SentenceTransformer
        The transformer model used for generating embeddings.
    embed_document_prompt : Optional[str]
        The prompt name used when embedding documents. If None, no document-specific
        prompt is applied and the model falls back to its default prompt, if one is
        configured.
    embed_query_prompt : Optional[str]
        The prompt name used when embedding queries. Defaults to "query".
    embedding_batch_size : int
        The batch size for embedding texts. Larger batch sizes may improve throughput
        but require more memory.

    Methods:
    --------
    embed_documents(texts: List[str]) -> ndarray:
        Generates embeddings for a list of documents.

        Parameters:
        -----------
        texts : List[str]
            A list of document strings to be embedded. Any None element is replaced
            with an empty string before embedding.

        Returns:
        --------
        ndarray
            A NumPy array of embeddings corresponding to the input documents.

    embed_query(query: str) -> List[float]:
        Generates an embedding for a single query string.

        Parameters:
        -----------
        query : str
            The query string to be embedded.

        Returns:
        --------
        List[float]
            A list representing the embedding vector for the query.

    See the usage sketch at the bottom of this module for an end-to-end example.
    """

    def __init__(self,
                 model_name: str = "Snowflake/snowflake-arctic-embed-m-v1.5",
                 embed_document_prompt: Optional[str] = None,
                 embed_query_prompt: Optional[str] = "query",
                 embedding_batch_size: int = 32):
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self.embed_document_prompt = embed_document_prompt
        self.embed_query_prompt = embed_query_prompt
        self.embedding_batch_size = embedding_batch_size

    def embed_documents(self, texts: List[str]) -> ndarray:
        texts = [text if text is not None else "" for text in texts]
        embeddings = self.model.encode(texts,
                                       batch_size=self.embedding_batch_size,
                                       prompt_name=self.embed_document_prompt)
        return embeddings

    def embed_query(self, query: str) -> List[float]:
        embedding = self.model.encode(query, prompt_name=self.embed_query_prompt)
        return embedding.tolist()
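

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how the
# embedder might be wired into a simple retrieval flow. The corpus strings and
# the cosine-similarity ranking below are hypothetical examples; running this
# downloads the default model on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    embedder = TransformerEmbedder()

    # Hypothetical corpus; None entries are tolerated by embed_documents.
    corpus = [
        "Snowflake Arctic embed models are optimized for retrieval.",
        "SentenceTransformer encodes text into dense vectors.",
        None,
    ]
    doc_vectors = embedder.embed_documents(corpus)                     # shape: (3, dim)
    query_vector = np.array(embedder.embed_query("what encodes text into vectors?"))

    # Rank documents by cosine similarity to the query.
    norms = np.linalg.norm(doc_vectors, axis=1) * np.linalg.norm(query_vector)
    scores = (doc_vectors @ query_vector) / np.clip(norms, 1e-12, None)
    for idx in np.argsort(-scores):
        print(f"{scores[idx]:.4f}  {corpus[idx]!r}")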