Skip to content

Commit

Permalink
add support for SentenceTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyash2106 committed Jun 3, 2024
1 parent 4377caf commit f99d767
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 7 deletions.
40 changes: 37 additions & 3 deletions src/agrag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import yaml

from agrag.defaults import DATA_PROCESSING_MODULE_DEFAULTS
from agrag.defaults import DATA_PROCESSING_MODULE_DEFAULTS, EMBEDDING_MODULE_DEFAULTS
from agrag.modules.data_processing.data_processing import DataProcessingModule
from agrag.modules.embedding.embedding import EmbeddingModule
from agrag.modules.generator.generator import GeneratorModule
Expand All @@ -19,12 +19,16 @@

def get_defaults_from_config():
    """Refresh the module-level default dicts from their YAML config files.

    Rebinds the module-level ``DATA_PROCESSING_MODULE_DEFAULTS`` and
    ``EMBEDDING_MODULE_DEFAULTS``: every entry whose current value is ``None``
    is replaced by the value found under the ``data`` section of the
    corresponding ``default.yaml``.
    """
    global DATA_PROCESSING_MODULE_DEFAULTS, EMBEDDING_MODULE_DEFAULTS

    def _merge_defaults(defaults, config_path):
        # Keep an explicitly-set value; fall back to the YAML entry otherwise.
        # Compare against None (not truthiness) so that deliberate falsy
        # defaults such as 0, False, or "" are not silently overridden.
        with open(config_path, "r") as f:
            doc = yaml.safe_load(f)
        return {k: (v if v is not None else doc["data"][k]) for k, v in defaults.items()}

    DATA_PROCESSING_MODULE_DEFAULTS = _merge_defaults(
        DATA_PROCESSING_MODULE_DEFAULTS,
        os.path.join(CURRENT_DIR, "configs/data_processing/default.yaml"),
    )
    # NOTE(review): the embedding config is also read from its "data" section,
    # mirroring the data-processing file — confirm the embedding YAML really
    # uses a "data" top-level key and not e.g. "embedding".
    EMBEDDING_MODULE_DEFAULTS = _merge_defaults(
        EMBEDDING_MODULE_DEFAULTS,
        os.path.join(CURRENT_DIR, "configs/embedding/default.yaml"),
    )


def get_args() -> argparse.Namespace:
Expand Down Expand Up @@ -60,6 +64,30 @@ def get_args() -> argparse.Namespace:
required=False,
default=DATA_PROCESSING_MODULE_DEFAULTS["CHUNK_OVERLAP"],
)
parser.add_argument(
"--hf_embedding_model",
type=str,
help="Huggingface model to use for generating embeddings",
metavar="",
required=False,
default=EMBEDDING_MODULE_DEFAULTS["HF_DEFAULT_MODEL"],
)
parser.add_argument(
"--st_embedding_model",
type=str,
help="Sentence Transformer model to use for generating embeddings",
metavar="",
required=False,
default=EMBEDDING_MODULE_DEFAULTS["ST_DEFAULT_MODEL"],
)
parser.add_argument(
"--pooling_strategy",
type=str,
help="Pooling method to use when pooling the embeddings generated by the embedding model",
metavar="",
required=False,
default=None,
)

args = parser.parse_args()
return args
Expand All @@ -73,14 +101,20 @@ def initialize_rag_pipeline() -> RetrieverModule:
chunk_size = args.chunk_size
chunk_overlap = args.chunk_overlap
s3_bucket = args.s3_bucket
hf_embedding_model = args.hf_embedding_model
st_embedding_model = args.st_embedding_model

pooling_strategy = args.pooling_strategy

logger.info(f"Processing Data from provided documents at {data_dir}")
data_processing_module = DataProcessingModule(
data_dir=data_dir, chunk_size=chunk_size, chunk_overlap=chunk_overlap, s3_bucket=s3_bucket
)
processed_data = data_processing_module.process_data()

embedding_module = EmbeddingModule()
embedding_module = EmbeddingModule(
hf_model=hf_embedding_model, st_model=st_embedding_model, pooling_strategy=pooling_strategy
)
embeddings = embedding_module.create_embeddings(processed_data)

vector_database_module = VectorDatabaseModule()
Expand Down
86 changes: 82 additions & 4 deletions src/agrag/modules/embedding/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,84 @@
import logging
from typing import List, Union

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from agrag.modules.embedding.utils import pool

logger = logging.getLogger("rag-logger")


class EmbeddingModule:
    """Generates embeddings for chunks of text data.

    Attributes:
    ----------
    hf_model : str
        Name of the Huggingface model used when no SentenceTransformer is requested.
    st_model : str
        Name of the SentenceTransformer model used on the SentenceTransformer path.
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for the Huggingface model (only set on the Huggingface path).
    model : Union[SentenceTransformer, transformers.PreTrainedModel]
        The loaded embedding model.
    pooling_strategy : str
        The strategy used for pooling embeddings. Options are 'average', 'max', 'cls'.
        If no option is provided, will default to using no pooling method.

    Methods:
    -------
    create_embeddings(data: List[str]) -> Union[List[torch.Tensor], torch.Tensor]:
        Generates embeddings for a list of text data chunks.
    """

    # Fallback SentenceTransformer checkpoint, used when the caller selects the
    # SentenceTransformer path via the "sentence_transformer" sentinel value.
    _ST_FALLBACK_MODEL = "paraphrase-MiniLM-L6-v2"

    def __init__(
        self,
        hf_model: str = "BAAI/bge-large-en",
        st_model: str = "paraphrase-MiniLM-L6-v2",
        pooling_strategy: str = None,
    ):
        self.sentence_transf = False
        self.hf_model = hf_model
        self.st_model = st_model
        if st_model == "sentence_transformer":
            # BUG FIX: the original passed the literal sentinel string
            # "sentence_transformer" to SentenceTransformer(), which is not a
            # real checkpoint name and cannot load. Load the default
            # SentenceTransformer checkpoint instead.
            # NOTE(review): dispatching on this sentinel means a concrete
            # --st_embedding_model name can never reach this branch — confirm
            # the intended selection mechanism against the CLI in main.py.
            self.st_model = self._ST_FALLBACK_MODEL
            self.model = SentenceTransformer(self.st_model)
            self.sentence_transf = True
            logger.info(f"Using SentenceTransformer Model: {self.st_model}")
        else:
            logger.info(f"Defaulting to Huggingface since no SentenceTransformer model was requested.")
            logger.info(f"Using Huggingface Model: {self.hf_model}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
            self.model = AutoModel.from_pretrained(self.hf_model)
        self.pooling_strategy = pooling_strategy

    def create_embeddings(self, data: List[str]) -> Union[List[torch.Tensor], torch.Tensor]:
        """
        Generates embeddings for a list of text data chunks.

        Parameters:
        ----------
        data : List[str]
            A list of text data chunks to generate embeddings for.

        Returns:
        -------
        Union[List[torch.Tensor], torch.Tensor]
            A list of embeddings corresponding to the input data chunks if no
            pooling_strategy is set, otherwise a single tensor with the pooled
            embeddings.
        """
        if self.sentence_transf:
            # BUG FIX: return directly — the original fell through to the
            # torch.cat() below, which expects a sequence of tensors and fails
            # on the single tensor produced by encode()/pool().
            embeddings = self.model.encode(data, convert_to_tensor=True)
            return pool(embeddings, self.pooling_strategy)
        embeddings = []
        for text in data:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
            with torch.no_grad():  # inference only; no gradients needed
                outputs = self.model(**inputs)
            embeddings.append(pool(outputs.last_hidden_state, self.pooling_strategy))
        if not self.pooling_strategy:
            return embeddings
        # With a pooling strategy, each entry is a pooled tensor for one chunk;
        # combine them into a single tensor.
        return torch.cat(embeddings, dim=0)

0 comments on commit f99d767

Please sign in to comment.