vector db module interface setup
shreyash2106 committed Jun 6, 2024
1 parent f8461eb commit 6f67864
Showing 5 changed files with 175 additions and 5 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -43,7 +43,9 @@ dependencies = [
"langchain-community",
"langchain-text-splitters",
"transformers",
"pypdf"
"pypdf",
"faiss-cpu",
"scikit-learn",
]


2 changes: 2 additions & 0 deletions requirements.txt
@@ -7,3 +7,5 @@ langchain-community
langchain-text-splitters
transformers
pypdf
faiss-cpu
scikit-learn
41 changes: 41 additions & 0 deletions src/agrag/modules/vector_db/faiss.py
@@ -0,0 +1,41 @@
import logging
from typing import List

import faiss
import numpy as np
import torch

logger = logging.getLogger("rag-logger")


def construct_faiss_index(embeddings: List[torch.Tensor], gpu: bool = False) -> faiss.Index:
    """
    Constructs a FAISS index and stores the embeddings.
    Parameters:
    ----------
    embeddings : List[torch.Tensor]
        A list of embeddings to be stored in the FAISS index.
    gpu : bool, optional
        Whether to move the index onto the first GPU device (default is False).
    Returns:
    -------
    faiss.Index
        The constructed FAISS index; a GPU-resident index if gpu is True.
    """
    d = embeddings[0].shape[-1]  # dimension of the vectors
    logger.info(f"Constructing FAISS index with dimension: {d}")

    index = faiss.IndexFlatL2(d)  # Flat (CPU) index, L2 distance

    if gpu:
        # Requires a GPU-enabled FAISS build; faiss-cpu does not provide StandardGpuResources
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
        logger.info("Using FAISS GPU index")

    # FAISS expects a contiguous float32 array of shape (num_embeddings, d)
    embeddings_array = torch.stack([emb.detach().cpu() for emb in embeddings]).numpy().astype(np.float32)
    index.add(embeddings_array)
    logger.info(f"Stored {embeddings_array.shape[0]} embeddings in the FAISS index")

    return index
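
As a quick sanity check, the helper can be exercised on its own. The sketch below is illustrative only: the 384-dimensional random embeddings, their count, and the nearest-neighbor query are assumptions, not part of this commit.

import torch

from agrag.modules.vector_db.faiss import construct_faiss_index

# Hypothetical inputs: ten 384-dimensional embeddings, e.g. from a sentence encoder
embeddings = [torch.randn(384) for _ in range(10)]
index = construct_faiss_index(embeddings, gpu=False)

# Retrieve the 3 nearest neighbors of the first embedding (FAISS wants float32, shape (n, d))
query = embeddings[0].numpy().reshape(1, -1)
distances, indices = index.search(query, 3)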
59 changes: 59 additions & 0 deletions src/agrag/modules/vector_db/utils.py
@@ -0,0 +1,59 @@
import logging
from typing import List

import torch
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger("rag-logger")


def remove_duplicates(embeddings: torch.Tensor, similarity_threshold: float) -> List[torch.Tensor]:
    """
    Removes duplicate embeddings based on cosine similarity.
    Parameters:
    ----------
    embeddings : torch.Tensor
        A tensor of embeddings, one per row, to be deduplicated.
    similarity_threshold : float
        Rows whose pairwise cosine similarity exceeds this value are treated as duplicates.
    Returns:
    -------
    List[torch.Tensor]
        A list of deduplicated embeddings.
    """
    if len(embeddings) <= 1:
        return list(embeddings)

    embeddings_array = embeddings.detach().cpu().numpy().reshape(len(embeddings), -1)
    similarity_matrix = cosine_similarity(embeddings_array)

    # For each pair above the threshold, keep the earlier row and drop the later one
    to_remove = set()
    for i in range(len(similarity_matrix)):
        for j in range(i + 1, len(similarity_matrix)):
            if similarity_matrix[i, j] > similarity_threshold:
                to_remove.add(j)

    deduplicated_embeddings = [embedding for i, embedding in enumerate(embeddings) if i not in to_remove]
    logger.info(f"Removed {len(to_remove)} duplicate embeddings")
    return deduplicated_embeddings


def pad_embeddings(embeddings: List[torch.Tensor]) -> torch.Tensor:
    """
    Pads embeddings to ensure they have the same length.
    Parameters:
    ----------
    embeddings : List[torch.Tensor]
        A list of embeddings to be padded; each is expected to carry its sequence
        dimension at index 1, e.g. shape (batch, seq_len, hidden_dim).
    Returns:
    -------
    torch.Tensor
        A tensor containing the padded embeddings, one flattened row per input.
    """
    max_len = max(embedding.shape[1] for embedding in embeddings)
    # Zero-pad the sequence dimension (second-to-last axis via F.pad's (0, 0, 0, n) form) up to max_len
    padded_embeddings = [
        torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[1])) for embedding in embeddings
    ]
    # Flatten each padded embedding into a single row: (num_embeddings, max_len * hidden_dim)
    return torch.cat(padded_embeddings, dim=0).view(len(padded_embeddings), -1)
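
To see how the two utilities compose, here is a minimal sketch; the shapes, sequence lengths, and threshold below are illustrative assumptions, not values fixed by this commit.

import torch

from agrag.modules.vector_db.utils import pad_embeddings, remove_duplicates

# Hypothetical token-level embeddings with varying sequence lengths, shape (1, seq_len, hidden_dim)
raw = [torch.randn(1, n, 768) for n in (12, 7, 12)]

padded = pad_embeddings(raw)  # tensor of shape (3, 12 * 768)
deduped = remove_duplicates(padded, similarity_threshold=0.95)
print(len(deduped))  # at most 3; near-identical rows are dropped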
74 changes: 70 additions & 4 deletions src/agrag/modules/vector_db/vector_database.py
@@ -1,6 +1,72 @@
import logging
from typing import Any, List, Union

import faiss
import torch

from agrag.modules.vector_db.faiss import construct_faiss_index
from agrag.modules.vector_db.utils import pad_embeddings, remove_duplicates

logger = logging.getLogger("rag-logger")


class VectorDatabaseModule:
    """
    A class used to construct and manage a vector database for storing embeddings.
    Attributes:
    ----------
    db_type : str
        The type of vector database to use (default is 'faiss').
    index : Any
        The vector database index.
    params : dict
        Additional parameters for configuring the Vector DB index.
    similarity_threshold : float
        The threshold for considering embeddings as duplicates based on cosine similarity.
    Methods:
    -------
    construct_vector_database(embeddings: List[torch.Tensor]) -> Any:
        Constructs the vector database and stores the embeddings.
    """

    def __init__(self, db_type: str = "faiss", params: dict = None, similarity_threshold: float = 0.95) -> None:
        """
        Initializes the VectorDatabaseModule with the specified type and parameters.
        Parameters:
        ----------
        db_type : str, optional
            The type of vector database to use (default is 'faiss').
        params : dict, optional
            Additional parameters for configuring the FAISS index.
        similarity_threshold : float, optional
            The threshold for considering embeddings as duplicates based on cosine similarity (default is 0.95).
        """
        self.db_type = db_type
        self.params = params if params is not None else {}
        self.similarity_threshold = similarity_threshold
        self.index = None

    def construct_vector_database(self, embeddings: List[torch.Tensor]) -> Union[faiss.IndexFlatL2, Any]:
        """
        Constructs the vector database and stores the embeddings.
        Parameters:
        ----------
        embeddings : List[torch.Tensor]
            A list of embeddings to be stored in the vector database.
        Returns:
        -------
        Union[faiss.IndexFlatL2, Any]
            The constructed vector database index.
        """
        embeddings = pad_embeddings(embeddings)
        embeddings = remove_duplicates(embeddings, self.similarity_threshold)
        if self.db_type == "faiss":
            self.index = construct_faiss_index(embeddings, self.params.get("gpu", False))
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")
        return self.index
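
Put together, the module can be driven end to end roughly as follows. This is a sketch under assumptions: the chunk embeddings are fabricated for illustration, and the "gpu" key in params mirrors the flag read inside construct_vector_database.

import torch

from agrag.modules.vector_db.vector_database import VectorDatabaseModule

db = VectorDatabaseModule(db_type="faiss", params={"gpu": False}, similarity_threshold=0.95)

# Hypothetical per-chunk embeddings, shape (1, seq_len, hidden_dim)
chunk_embeddings = [torch.randn(1, 16, 384) for _ in range(100)]
index = db.construct_vector_database(chunk_embeddings)
print(index.ntotal)  # number of vectors stored after deduplication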
