-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f8461eb
commit 6f67864
Showing
5 changed files
with
175 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,5 @@ langchain-community | |
langchain-text-splitters | ||
transformers | ||
pypdf | ||
faiss-cpu | ||
scikit-learn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import logging | ||
from typing import List, Union | ||
|
||
import faiss | ||
import numpy as np | ||
import torch | ||
|
||
from agrag.modules.vector_db.utils import pad_embeddings | ||
|
||
logger = logging.getLogger("rag-logger") | ||
|
||
|
||
def construct_faiss_index(embeddings: List[torch.Tensor], gpu: bool) -> faiss.IndexFlatL2: | ||
""" | ||
Constructs a FAISS index and stores the embeddings. | ||
Parameters: | ||
---------- | ||
embeddings : List[torch.Tensor] | ||
A list of embeddings to be stored in the FAISS index. | ||
Returns: | ||
------- | ||
Union[faiss.IndexFlatL2, faiss.GpuIndexFlatL2] | ||
The constructed FAISS index. | ||
""" | ||
d = embeddings[0].shape[-1] # dimension of the vectors | ||
logger.info(f"Constructing FAISS index with dimension: {d}") | ||
|
||
index = faiss.IndexFlatL2(d) # Flat (CPU) index, L2 distance | ||
|
||
if gpu: | ||
res = faiss.StandardGpuResources() | ||
index = faiss.index_cpu_to_gpu(res, 0, index) | ||
logger.info("Using FAISS GPU index") | ||
|
||
embeddings_array = np.array(embeddings) | ||
index.add(embeddings_array) | ||
logger.info(f"Stored {embeddings_array.shape[0]} embeddings in the FAISS index") | ||
|
||
return index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import logging | ||
from typing import List | ||
|
||
import torch | ||
from sklearn.metrics.pairwise import cosine_similarity | ||
|
||
logger = logging.getLogger("rag-logger") | ||
|
||
|
||
def remove_duplicates(embeddings: List[torch.Tensor], similarity_threshold: float) -> List[torch.Tensor]: | ||
""" | ||
Removes duplicate embeddings based on cosine similarity. | ||
Parameters: | ||
---------- | ||
embeddings : List[torch.Tensor] | ||
A list of embeddings to be deduplicated. | ||
Returns: | ||
------- | ||
List[torch.Tensor] | ||
A list of deduplicated embeddings. | ||
""" | ||
if len(embeddings) <= 1: | ||
return embeddings | ||
|
||
embeddings_array = embeddings.numpy().reshape(len(embeddings), -1) | ||
similarity_matrix = cosine_similarity(embeddings_array) | ||
|
||
to_remove = set() | ||
for i in range(len(similarity_matrix)): | ||
for j in range(i + 1, len(similarity_matrix)): | ||
if similarity_matrix[i, j] > similarity_threshold: | ||
to_remove.add(j) | ||
|
||
deduplicated_embeddings = [embedding for i, embedding in enumerate(embeddings) if i not in to_remove] | ||
logger.info(f"Removed {len(to_remove)} duplicate embeddings") | ||
return deduplicated_embeddings | ||
|
||
|
||
def pad_embeddings(embeddings: List[torch.Tensor]) -> torch.Tensor: | ||
""" | ||
Pads embeddings to ensure they have the same length. | ||
Parameters: | ||
---------- | ||
embeddings : List[torch.Tensor] | ||
A list of embeddings to be padded. | ||
Returns: | ||
------- | ||
torch.Tensor | ||
A tensor containing the padded embeddings. | ||
""" | ||
max_len = max(embedding.shape[1] for embedding in embeddings) | ||
padded_embeddings = [ | ||
torch.nn.functional.pad(embedding, (0, 0, 0, max_len - embedding.shape[1])) for embedding in embeddings | ||
] | ||
return torch.cat(padded_embeddings, dim=0).view(len(padded_embeddings), -1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,72 @@ | ||
import logging | ||
from typing import Any, List, Union | ||
|
||
import faiss | ||
import torch | ||
|
||
from agrag.modules.vector_db.faiss import construct_faiss_index | ||
from agrag.modules.vector_db.utils import pad_embeddings, remove_duplicates | ||
|
||
logger = logging.getLogger("rag-logger") | ||
|
||
|
||
class VectorDatabaseModule: | ||
def __init__(self): | ||
pass | ||
""" | ||
A class used to construct and manage a vector database for storing embeddings. | ||
Attributes: | ||
---------- | ||
db_type : str | ||
The type of vector database to use (default is 'faiss'). | ||
index : Any | ||
The vector database index. | ||
params : dict | ||
Additional parameters for configuring the Vector DB index. | ||
similarity_threshold : float | ||
The threshold for considering embeddings as duplicates based on cosine similarity. | ||
Methods: | ||
------- | ||
construct_vector_database(embeddings: List[torch.Tensor]) -> Any: | ||
Constructs the vector database and stores the embeddings. | ||
""" | ||
|
||
def __init__(self, db_type: str = "faiss", params: dict = None, similarity_threshold: float = 0.95) -> None: | ||
""" | ||
Initializes the VectorDatabaseModule with the specified type and parameters. | ||
Parameters: | ||
---------- | ||
db_type : str, optional | ||
The type of vector database to use (default is 'faiss'). | ||
params : dict, optional | ||
Additional parameters for configuring the FAISS index. | ||
similarity_threshold : float, optional | ||
The threshold for considering embeddings as duplicates based on cosine similarity (default is 0.95). | ||
""" | ||
self.db_type = db_type | ||
self.params = params if params is not None else {} | ||
self.similarity_threshold = similarity_threshold | ||
self.index = None | ||
|
||
def construct_vector_database(self, embeddings: List[torch.Tensor]) -> Union[faiss.IndexFlatL2, Any]: | ||
""" | ||
Constructs the vector database and stores the embeddings. | ||
Parameters: | ||
---------- | ||
embeddings : List[torch.Tensor] | ||
A list of embeddings to be stored in the vector database. | ||
def construct_vector_database(self, embeddings): | ||
pass | ||
Returns: | ||
------- | ||
Union[faiss.IndexFlatL2, Any] | ||
The constructed vector database index. | ||
""" | ||
embeddings = pad_embeddings(embeddings) | ||
embeddings = remove_duplicates(embeddings, self.similarity_threshold) | ||
if self.db_type == "faiss": | ||
self.index = construct_faiss_index(embeddings, self.params.get("gpu", False)) | ||
else: | ||
raise ValueError(f"Unsupported database type: {self.db_type}") | ||
return self.index |