diff --git a/src/agrag/main.py b/src/agrag/main.py
index 8222dcfc..7e613075 100644
--- a/src/agrag/main.py
+++ b/src/agrag/main.py
@@ -4,7 +4,7 @@
 
 import yaml
 
-from agrag.defaults import DATA_PROCESSING_MODULE_DEFAULTS
+from agrag.defaults import DATA_PROCESSING_MODULE_DEFAULTS, EMBEDDING_MODULE_DEFAULTS
 from agrag.modules.data_processing.data_processing import DataProcessingModule
 from agrag.modules.embedding.embedding import EmbeddingModule
 from agrag.modules.generator.generator import GeneratorModule
@@ -19,12 +19,16 @@ def get_defaults_from_config():
     DATA_PROCESSING_MODULE_CONFIG = os.path.join(CURRENT_DIR, "configs/data_processing/default.yaml")
-    global DATA_PROCESSING_MODULE_DEFAULTS
+    EMBEDDING_MODULE_CONFIG = os.path.join(CURRENT_DIR, "configs/embedding/default.yaml")
+    global DATA_PROCESSING_MODULE_DEFAULTS, EMBEDDING_MODULE_DEFAULTS
     with open(DATA_PROCESSING_MODULE_CONFIG, "r") as f:
         doc = yaml.safe_load(f)
         DATA_PROCESSING_MODULE_DEFAULTS = dict(
             (k, v if v else doc["data"][k]) for k, v in DATA_PROCESSING_MODULE_DEFAULTS.items()
         )
+    with open(EMBEDDING_MODULE_CONFIG, "r") as f:
+        doc = yaml.safe_load(f)
+        EMBEDDING_MODULE_DEFAULTS = dict((k, v if v else doc["data"][k]) for k, v in EMBEDDING_MODULE_DEFAULTS.items())
 
 
 def get_args() -> argparse.Namespace:
@@ -60,6 +64,30 @@ def get_args() -> argparse.Namespace:
         required=False,
         default=DATA_PROCESSING_MODULE_DEFAULTS["CHUNK_OVERLAP"],
     )
+    parser.add_argument(
+        "--hf_embedding_model",
+        type=str,
+        help="Huggingface model to use for generating embeddings",
+        metavar="",
+        required=False,
+        default=EMBEDDING_MODULE_DEFAULTS["HF_DEFAULT_MODEL"],
+    )
+    parser.add_argument(
+        "--st_embedding_model",
+        type=str,
+        help="Sentence Transformer model to use for generating embeddings",
+        metavar="",
+        required=False,
+        default=EMBEDDING_MODULE_DEFAULTS["ST_DEFAULT_MODEL"],
+    )
+    parser.add_argument(
+        "--pooling_strategy",
+        type=str,
+        help="Pooling method to use when pooling the embeddings generated by the embedding model",
+        metavar="",
+        required=False,
+        default=None,
+    )
 
     args = parser.parse_args()
     return args
@@ -73,6 +101,10 @@ def initialize_rag_pipeline() -> RetrieverModule:
     chunk_size = args.chunk_size
     chunk_overlap = args.chunk_overlap
     s3_bucket = args.s3_bucket
+    hf_embedding_model = args.hf_embedding_model
+    st_embedding_model = args.st_embedding_model
+
+    pooling_strategy = args.pooling_strategy
 
     logger.info(f"Processing Data from provided documents at {data_dir}")
     data_processing_module = DataProcessingModule(
@@ -80,7 +112,9 @@ def initialize_rag_pipeline() -> RetrieverModule:
     )
     processed_data = data_processing_module.process_data()
 
-    embedding_module = EmbeddingModule()
+    embedding_module = EmbeddingModule(
+        hf_model=hf_embedding_model, st_model=st_embedding_model, pooling_strategy=pooling_strategy
+    )
     embeddings = embedding_module.create_embeddings(processed_data)
 
     vector_database_module = VectorDatabaseModule()
diff --git a/src/agrag/modules/embedding/embedding.py b/src/agrag/modules/embedding/embedding.py
index f148f3b5..c3b8ff05 100644
--- a/src/agrag/modules/embedding/embedding.py
+++ b/src/agrag/modules/embedding/embedding.py
@@ -1,6 +1,84 @@
+import logging
+from typing import List, Union
+
+import torch
+from sentence_transformers import SentenceTransformer
+from transformers import AutoModel, AutoTokenizer
+
+from agrag.modules.embedding.utils import pool
+
+logger = logging.getLogger("rag-logger")
+
+
 class EmbeddingModule:
-    def __init__(self):
-        pass
+    """
+    A class used to generate embeddings for text data.
+
+    Attributes:
+    ----------
+    model_name : str
+        The name of the Huggingface model or SentenceTransformer to use for generating embeddings.
+    tokenizer : transformers.PreTrainedTokenizer
+        The tokenizer associated with the Huggingface model.
+    model : transformers.PreTrainedModel
+        The Huggingface model used for generating embeddings.
+    pooling_strategy : str
+        The strategy used for pooling embeddings. Options are 'average', 'max', 'cls'.
+        If no option is provided, will default to using no pooling method.
+
+    Methods:
+    -------
+    create_embeddings(data: List[str]) -> List[torch.Tensor]:
+        Generates embeddings for a list of text data chunks.
+    """
+
+    def __init__(
+        self,
+        hf_model: str = "BAAI/bge-large-en",
+        st_model: str = "paraphrase-MiniLM-L6-v2",
+        pooling_strategy: str = None,
+    ):
+        self.sentence_transf = False
+        self.hf_model = hf_model
+        self.st_model = st_model
+        if st_model == "sentence_transformer":
+            self.model = SentenceTransformer(self.st_model)
+            self.sentence_transf = True
+        else:
+            logger.info("Defaulting to Huggingface since no model was provided.")
+            logger.info(f"Using Huggingface Model: {self.hf_model}")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_model)
+            self.model = AutoModel.from_pretrained(self.hf_model)
+        self.pooling_strategy = pooling_strategy
+
+    def create_embeddings(self, data: List[str]) -> Union[List[torch.Tensor], torch.Tensor]:
+        """
+        Generates embeddings for a list of text data chunks.
+
+        Parameters:
+        ----------
+        data : List[str]
+            A list of text data chunks to generate embeddings for.
 
-    def create_embeddings(self, data):
-        pass
+        Returns:
+        -------
+        Union[List[torch.Tensor], torch.Tensor]
+            A list of embeddings corresponding to the input data chunks if pooling_strategy is 'none',
+            otherwise a single tensor with the pooled embeddings.
+        """
+        if self.sentence_transf:
+            embeddings = self.model.encode(data, convert_to_tensor=True)
+            embeddings = pool(embeddings, self.pooling_strategy)
+        else:
+            embeddings = []
+            for text in data:
+                inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+                with torch.no_grad():
+                    outputs = self.model(**inputs)
+                embedding = pool(outputs.last_hidden_state, self.pooling_strategy)
+                embeddings.append(embedding)
+        if not self.pooling_strategy:
+            return embeddings
+        else:
+            # Combine pooled embeddings into a single tensor
+            return torch.cat(embeddings, dim=0)
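
Note on the pool helper: the new embedding.py imports pool from agrag.modules.embedding.utils, but that module is not part of these hunks. As a rough sketch only (the option names 'average', 'max', 'cls' come from the docstring above; the exact signature and behavior of the real utils module are assumptions), pool could look like this:

    import torch

    def pool(embeddings: torch.Tensor, pooling_strategy: str = None) -> torch.Tensor:
        # Reduce token-level embeddings of shape (batch, seq_len, hidden) to
        # (batch, hidden); return the input unchanged when no strategy is set.
        if not pooling_strategy:
            return embeddings
        if pooling_strategy == "average":
            return embeddings.mean(dim=1)
        if pooling_strategy == "max":
            return embeddings.max(dim=1).values
        if pooling_strategy == "cls":
            # Embedding of the first ([CLS]) token.
            return embeddings[:, 0, :]
        raise ValueError(f"Unsupported pooling strategy: {pooling_strategy}")

Under that sketch, the Huggingface branch passes outputs.last_hidden_state of shape (1, seq_len, hidden) per chunk, so 'average', 'max', or 'cls' each yield a (1, hidden) tensor and torch.cat(embeddings, dim=0) stacks them into (num_chunks, hidden); with no pooling strategy, create_embeddings returns the per-chunk tensors as a list.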