From 6f67864b9d961dd036bf1d0459ee7e51121a2540 Mon Sep 17 00:00:00 2001
From: Shreyash
Date: Thu, 6 Jun 2024 11:26:57 -0700
Subject: [PATCH] vector db module interface setup

---
Review notes (text between the three-dash line and the first diff is not part
of the commit message):
- utils.py: remove_duplicates no longer calls `.numpy()` on the container, so it
  works for a plain list of tensors as well as for a stacked tensor; the pairwise
  Python loop is replaced by an equivalent upper-triangle mask.
- faiss.py: empty input raises ValueError, embeddings are converted to a
  C-contiguous float32 matrix explicitly, unused imports are dropped.
- utils.py: pad_embeddings raises a descriptive ValueError on empty input.
- vector_database.py: `params` is annotated as Optional[dict].
- Blob ids on the `index` lines are carried over from the original revision.

 pyproject.toml                                |  4 +-
 requirements.txt                              |  2 +
 src/agrag/modules/vector_db/faiss.py          | 54 +++++++++++++
 src/agrag/modules/vector_db/utils.py          | 75 ++++++++++++++++++
 .../modules/vector_db/vector_database.py      | 74 ++++++++++++++++++-
 5 files changed, 204 insertions(+), 5 deletions(-)
 create mode 100644 src/agrag/modules/vector_db/faiss.py
 create mode 100644 src/agrag/modules/vector_db/utils.py

diff --git a/pyproject.toml b/pyproject.toml
index 443d885e..96887cdb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,7 +43,9 @@ dependencies = [
     "langchain-community",
     "langchain-text-splitters",
     "transformers",
-    "pypdf"
+    "pypdf",
+    "faiss-cpu",
+    "scikit-learn",
 ]
 
 
diff --git a/requirements.txt b/requirements.txt
index 2436439b..fdc5fff8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,5 @@ langchain-community
 langchain-text-splitters
 transformers
 pypdf
+faiss-cpu
+scikit-learn
diff --git a/src/agrag/modules/vector_db/faiss.py b/src/agrag/modules/vector_db/faiss.py
new file mode 100644
index 00000000..5f9d7469
--- /dev/null
+++ b/src/agrag/modules/vector_db/faiss.py
@@ -0,0 +1,54 @@
+import logging
+from typing import List
+
+import faiss
+import numpy as np
+import torch
+
+logger = logging.getLogger("rag-logger")
+
+
+def construct_faiss_index(embeddings: List[torch.Tensor], gpu: bool) -> faiss.Index:
+    """
+    Constructs a FAISS index and stores the embeddings.
+
+    Parameters:
+    ----------
+    embeddings : List[torch.Tensor]
+        A non-empty list of equally sized embeddings to be stored in the FAISS index.
+        A single stacked tensor with one embedding per row is accepted as well.
+    gpu : bool
+        If True, the index is moved to GPU 0 before the embeddings are added.
+
+    Returns:
+    -------
+    faiss.Index
+        The constructed flat L2 index (a GPU index when `gpu` is True).
+
+    Raises:
+    ------
+    ValueError
+        If `embeddings` is empty.
+    """
+    # len() rather than truthiness: a stacked tensor has no unambiguous truth value.
+    if len(embeddings) == 0:
+        raise ValueError("Cannot construct a FAISS index from an empty list of embeddings")
+
+    d = embeddings[0].shape[-1]  # dimension of the vectors
+    logger.info(f"Constructing FAISS index with dimension: {d}")
+
+    index = faiss.IndexFlatL2(d)  # Flat (CPU) index, L2 distance
+
+    if gpu:
+        res = faiss.StandardGpuResources()
+        index = faiss.index_cpu_to_gpu(res, 0, index)
+        logger.info("Using FAISS GPU index")
+
+    # FAISS expects a C-contiguous float32 matrix of shape (n, d). Detach and move
+    # each tensor to host memory first so the NumPy conversion cannot fail.
+    rows = [torch.as_tensor(embedding).detach().cpu().float().numpy() for embedding in embeddings]
+    embeddings_array = np.ascontiguousarray(np.stack(rows).reshape(-1, d), dtype=np.float32)
+    index.add(embeddings_array)
+    logger.info(f"Stored {embeddings_array.shape[0]} embeddings in the FAISS index")
+
+    return index
diff --git a/src/agrag/modules/vector_db/utils.py b/src/agrag/modules/vector_db/utils.py
new file mode 100644
index 00000000..0ca8db00
--- /dev/null
+++ b/src/agrag/modules/vector_db/utils.py
@@ -0,0 +1,75 @@
+import logging
+from typing import List
+
+import numpy as np
+import torch
+from sklearn.metrics.pairwise import cosine_similarity
+
+logger = logging.getLogger("rag-logger")
+
+
+def remove_duplicates(embeddings: List[torch.Tensor], similarity_threshold: float) -> List[torch.Tensor]:
+    """
+    Removes duplicate embeddings based on cosine similarity.
+
+    Parameters:
+    ----------
+    embeddings : List[torch.Tensor]
+        A list of equally sized embeddings to be deduplicated. A single stacked
+        tensor with one embedding per row is accepted as well.
+    similarity_threshold : float
+        An embedding is dropped when its cosine similarity with any earlier
+        embedding is strictly greater than this value.
+
+    Returns:
+    -------
+    List[torch.Tensor]
+        A list of deduplicated embeddings. An input holding at most one embedding
+        is returned unchanged.
+    """
+    if len(embeddings) <= 1:
+        return embeddings
+
+    # Flatten every embedding to one row. Converting element by element (instead of
+    # calling `.numpy()` on the container) works for plain lists and stacked tensors.
+    rows = [torch.as_tensor(embedding).detach().cpu().reshape(-1) for embedding in embeddings]
+    similarity_matrix = cosine_similarity(torch.stack(rows).numpy())
+
+    # Column j of the strict upper triangle holds the similarities between embedding j
+    # and every earlier embedding; any hit above the threshold marks j as a duplicate.
+    is_duplicate = np.triu(similarity_matrix > similarity_threshold, k=1).any(axis=0)
+
+    deduplicated_embeddings = [embedding for embedding, dup in zip(embeddings, is_duplicate) if not dup]
+    logger.info(f"Removed {int(is_duplicate.sum())} duplicate embeddings")
+    return deduplicated_embeddings
+
+
+def pad_embeddings(embeddings: List[torch.Tensor]) -> torch.Tensor:
+    """
+    Pads embeddings to ensure they have the same length.
+
+    Parameters:
+    ----------
+    embeddings : List[torch.Tensor]
+        A non-empty list of embeddings to be padded. Each one is zero-padded at the
+        end of its second-to-last dimension up to the largest `shape[1]` in the list
+        (this presumes 3-D tensors shaped like (1, sequence_length, hidden_size)).
+
+    Returns:
+    -------
+    torch.Tensor
+        A tensor containing the padded embeddings, reshaped to `len(embeddings)` rows.
+
+    Raises:
+    ------
+    ValueError
+        If `embeddings` is empty.
+    """
+    if len(embeddings) == 0:
+        raise ValueError("Cannot pad an empty list of embeddings")
+
+    max_len = max(embedding.shape[1] for embedding in embeddings)
+    padded_embeddings = [
+        torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[1])) for embedding in embeddings
+    ]
+    return torch.cat(padded_embeddings, dim=0).view(len(padded_embeddings), -1)
diff --git a/src/agrag/modules/vector_db/vector_database.py b/src/agrag/modules/vector_db/vector_database.py
index 209afd96..ad374426 100644
--- a/src/agrag/modules/vector_db/vector_database.py
+++ b/src/agrag/modules/vector_db/vector_database.py
@@ -1,6 +1,72 @@
+import logging
+from typing import Any, List, Optional, Union
+
+import faiss
+import torch
+
+from agrag.modules.vector_db.faiss import construct_faiss_index
+from agrag.modules.vector_db.utils import pad_embeddings, remove_duplicates
+
+logger = logging.getLogger("rag-logger")
+
+
 class VectorDatabaseModule:
-    def __init__(self):
-        pass
+    """
+    A class used to construct and manage a vector database for storing embeddings.
+
+    Attributes:
+    ----------
+    db_type : str
+        The type of vector database to use (default is 'faiss').
+    index : Any
+        The vector database index.
+    params : dict
+        Additional parameters for configuring the Vector DB index.
+    similarity_threshold : float
+        The threshold for considering embeddings as duplicates based on cosine similarity.
+
+    Methods:
+    -------
+    construct_vector_database(embeddings: List[torch.Tensor]) -> Any:
+        Constructs the vector database and stores the embeddings.
+    """
+
+    def __init__(self, db_type: str = "faiss", params: Optional[dict] = None, similarity_threshold: float = 0.95) -> None:
+        """
+        Initializes the VectorDatabaseModule with the specified type and parameters.
+
+        Parameters:
+        ----------
+        db_type : str, optional
+            The type of vector database to use (default is 'faiss').
+        params : dict, optional
+            Additional parameters for configuring the FAISS index.
+        similarity_threshold : float, optional
+            The threshold for considering embeddings as duplicates based on cosine similarity (default is 0.95).
+        """
+        self.db_type = db_type
+        self.params = params if params is not None else {}
+        self.similarity_threshold = similarity_threshold
+        self.index = None
+
+    def construct_vector_database(self, embeddings: List[torch.Tensor]) -> Union[faiss.IndexFlatL2, Any]:
+        """
+        Constructs the vector database and stores the embeddings.
+
+        Parameters:
+        ----------
+        embeddings : List[torch.Tensor]
+            A list of embeddings to be stored in the vector database.
 
-    def construct_vector_database(self, embeddings):
-        pass
+        Returns:
+        -------
+        Union[faiss.IndexFlatL2, Any]
+            The constructed vector database index.
+        """
+        embeddings = pad_embeddings(embeddings)
+        embeddings = remove_duplicates(embeddings, self.similarity_threshold)
+        if self.db_type == "faiss":
+            self.index = construct_faiss_index(embeddings, self.params.get("gpu", False))
+        else:
+            raise ValueError(f"Unsupported database type: {self.db_type}")
+        return self.index