From f2937adcd0ee6f02589e922e0b426a95e8410d0a Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 13:37:21 -0800
Subject: [PATCH] Add ChromaVS and QdrantVS implementations and refactor vector store constructors

---
 lotus/vector_store/chroma_vs.py   | 147 ++++++++++++++++++++++++++--
 lotus/vector_store/pinecone_vs.py |   2 +-
 lotus/vector_store/qdrant_vs.py   | 155 ++++++++++++++++++++++++++++--
 lotus/vector_store/vs.py          |   3 +-
 lotus/vector_store/weaviate_vs.py |   2 +-
 5 files changed, 289 insertions(+), 20 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index 298f05e..d05b1e4 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,16 +1,147 @@
+from typing import Any, Callable, Union
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+from PIL import Image
+from tqdm import tqdm
+
+from lotus.types import RMOutput
 from lotus.vector_store.vs import VS
 
+try:
+    import chromadb
+    from chromadb.api import Collection
+except ImportError as err:
+    raise ImportError(
+        "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"
+    ) from err
+
 
 class ChromaVS(VS):
-    def __init__(self):
+    def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+        """Initialize with ChromaDB client and embedding model"""
+        super().__init__(embedding_model)
+        self.client = client
+        self.collection: Collection | None = None
+        self.collection_name: str | None = None
+        self.max_batch_size = max_batch_size
+
+    def index(self, docs: pd.Series, collection_name: str):
+        """Create a collection and add documents with their embeddings"""
+        self.collection_name = collection_name
+
+        # Create collection without an embedding function (we provide embeddings directly)
+        self.collection = self.client.create_collection(
+            name=collection_name,
+            metadata={"hnsw:space": "cosine"}  # Use cosine similarity for consistency
+        )
+
+        # Convert docs to list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Generate embeddings
+        embeddings = self._batch_embed(docs_list)
+
+        # Prepare documents for addition
+        ids = [str(i) for i in range(len(docs_list))]
+        metadatas = [{"doc_id": i} for i in range(len(docs_list))]
+
+        # Add documents in batches
+        batch_size = 100
+        for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"):
+            end_idx = min(i + batch_size, len(docs_list))
+            self.collection.add(
+                ids=ids[i:end_idx],
+                documents=docs_list[i:end_idx],
+                embeddings=embeddings[i:end_idx].tolist(),
+                metadatas=metadatas[i:end_idx]
+            )
+
+    def load_index(self, collection_name: str):
+        """Load an existing collection"""
         try:
-            import chromadb
-        except ImportError:
-            chromadb = None
+            self.collection = self.client.get_collection(collection_name)
+            self.collection_name = collection_name
+        except ValueError as e:
+            raise ValueError(f"Collection {collection_name} not found") from e
 
+    def __call__(
+        self,
+        queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
+        K: int,
+        **kwargs: dict[str, Any]
+    ) -> RMOutput:
+        """Perform vector search using ChromaDB"""
+        if self.collection is None:
+            raise ValueError("No collection loaded. Call load_index first.")
+
+        # Convert single query to list
+        if isinstance(queries, (str, Image.Image)):
+            queries = [queries]
+
+        # Handle numpy array queries (pre-computed vectors)
+        if isinstance(queries, np.ndarray):
+            query_vectors = queries
+        else:
+            # Convert queries to list if needed
+            if isinstance(queries, pd.Series):
+                queries = queries.tolist()
+            # Create embeddings for text queries
+            query_vectors = self._batch_embed(queries)
+
+        # Perform searches
+        all_distances = []
+        all_indices = []
 
-        if chromadb is None:
-            raise ImportError(
-                "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`",
+        for query_vector in query_vectors:
+            results = self.collection.query(
+                query_embeddings=[query_vector.tolist()],
+                n_results=K,
+                include=['metadatas', 'distances']
             )
-        pass
+
+            # Extract distances and indices
+            distances = []
+            indices = []
+
+            if results['metadatas']:
+                for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
+                    indices.append(metadata['doc_id'])
+                    # The collection uses cosine space, so ChromaDB returns cosine distance
+                    # (1 - cosine similarity); convert it back to a similarity score
+                    distances.append(1 - distance)
+
+            # Pad results if fewer than K matches
+            while len(indices) < K:
+                indices.append(-1)
+                distances.append(0.0)
+
+            all_distances.append(distances)
+            all_indices.append(indices)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32),
+            indices=np.array(all_indices, dtype=np.int64)
+        )
+
+    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        if self.collection is None or self.collection_name != collection_name:
+            self.load_index(collection_name)
+
+        # Convert integer ids to strings for ChromaDB
+        str_ids = [str(id) for id in ids]
+
+        # Get embeddings from ChromaDB
+        results = self.collection.get(
+            ids=str_ids,
+            include=['embeddings']
+        )
+
+        if not results['embeddings']:
+            raise ValueError("No vectors found for the given ids")
+
+        return np.array(results['embeddings'], dtype=np.float64)
\ No newline at end of file
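
The new ChromaVS exposes the same index / load_index / __call__ / get_vectors_from_index surface as the other backends in this patch. The snippet below is a minimal usage sketch, not part of the patch: it assumes an in-process chromadb.Client(), assumes that the base class's _batch_embed simply batches calls to the supplied embedding_model, and uses a placeholder embed() function in place of a real embedding model.

    import chromadb
    import numpy as np
    import pandas as pd

    from lotus.vector_store.chroma_vs import ChromaVS


    def embed(texts) -> np.ndarray:
        # Placeholder embedding model: pseudo-random 384-dim vectors derived from each text's hash.
        # A real setup would call an actual embedding model here.
        seeds = [abs(hash(t)) % (2**32) for t in list(texts)]
        return np.stack([np.random.default_rng(s).standard_normal(384) for s in seeds])


    vs = ChromaVS(client=chromadb.Client(), embedding_model=embed)
    docs = pd.Series(["cats purr", "dogs bark", "fish swim"])
    vs.index(docs, collection_name="animals")

    out = vs("which pet makes noise?", K=2)   # RMOutput with .distances and .indices
    print(out.indices, out.distances)
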
Call load_index first.") + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + # Perform searches + all_distances = [] + all_indices = [] - if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", + for query_vector in query_vectors: + results = self.collection.query( + query_embeddings=[query_vector.tolist()], + n_results=K, + include=['metadatas', 'distances'] ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + if results['metadatas']: + for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): + indices.append(metadata['doc_id']) + # ChromaDB returns squared L2 distances, convert to cosine similarity + # similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity + distances.append(1 - (distance / 2)) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection is None or self.collection_name != collection_name: + self.load_index(collection_name) + + # Convert integer ids to strings for ChromaDB + str_ids = [str(id) for id in ids] + + # Get embeddings from ChromaDB + results = self.collection.get( + ids=str_ids, + include=['embeddings'] + ) + + if not results['embeddings']: + raise ValueError("No vectors found for the given ids") + + return np.array(results['embeddings'], dtype=np.float64) + + \ No newline at end of file diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 19e7f81..da8cc34 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -17,7 +17,7 @@ ) from err class PineconeVS(VS): - def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 5ded2d7..c672727 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,16 +1,153 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from qdrant_client import QdrantClient + from qdrant_client.models import Distance, PointStruct, VectorParams +except ImportError as err: + raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self): - try: - import 
diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index 5ded2d7..c672727 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -1,16 +1,153 @@
+from typing import Any, Callable, Union
+
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+from PIL import Image
+from tqdm import tqdm
+
+from lotus.types import RMOutput
 from lotus.vector_store.vs import VS
 
+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.models import Distance, PointStruct, VectorParams
+except ImportError as err:
+    raise ImportError("Please install the qdrant client") from err
+
 
 class QdrantVS(VS):
-    def __init__(self):
-        try:
-            import qdrant_client
-        except ImportError:
-            qdrant_client = None
+    def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+        """Initialize with Qdrant client and embedding model"""
+        super().__init__(embedding_model)
+        self.client = client
+        self.max_batch_size = max_batch_size
+
+    def index(self, docs: pd.Series, collection_name: str):
+        """Create a collection and add documents with their embeddings"""
+        self.collection_name = collection_name
+
+        # Get a sample embedding to determine the vector dimension
+        sample_embedding = self._embed([docs.iloc[0]])
+        dimension = sample_embedding.shape[1]
+
+        # Create the collection if it doesn't exist
+        if not self.client.collection_exists(collection_name):
+            self.client.create_collection(
+                collection_name=collection_name,
+                vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
+            )
+
+        # Convert docs to list if it's a pandas Series
+        docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs
+
+        # Generate embeddings
+        embeddings = self._batch_embed(docs_list)
+
+        # Prepare points for upload
+        points = []
+        for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
+            points.append(
+                PointStruct(
+                    id=idx,
+                    vector=embedding.tolist(),
+                    payload={
+                        "content": doc,
+                        "doc_id": idx
+                    }
+                )
+            )
+
+        # Upload in batches
+        batch_size = 100
+        for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"):
+            batch = points[i:i + batch_size]
+            self.client.upsert(
+                collection_name=collection_name,
+                points=batch
+            )
+
+    def load_index(self, collection_name: str):
+        """Set the collection name to use"""
+        if not self.client.collection_exists(collection_name):
+            raise ValueError(f"Collection {collection_name} not found")
+        self.collection_name = collection_name
 
+    def __call__(
+        self,
+        queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
+        K: int,
+        **kwargs: dict[str, Any]
+    ) -> RMOutput:
+        """Perform vector search using Qdrant"""
+        if self.collection_name is None:
+            raise ValueError("No collection loaded. Call load_index first.")
 
-        if qdrant_client is None:
-            raise ImportError(
-                "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`",
+        # Convert single query to list
+        if isinstance(queries, (str, Image.Image)):
+            queries = [queries]
+
+        # Handle numpy array queries (pre-computed vectors)
+        if isinstance(queries, np.ndarray):
+            query_vectors = queries
+        else:
+            # Convert queries to list if needed
+            if isinstance(queries, pd.Series):
+                queries = queries.tolist()
+            # Create embeddings for text queries
+            query_vectors = self._batch_embed(queries)
+
+        # Perform searches
+        all_distances = []
+        all_indices = []
+
+        for query_vector in query_vectors:
+            results = self.client.search(
+                collection_name=self.collection_name,
+                query_vector=query_vector.tolist(),
+                limit=K,
+                with_payload=True
             )
-        pass
+
+            # Extract distances and indices
+            distances = []
+            indices = []
+
+            for result in results:
+                indices.append(result.payload["doc_id"])
+                distances.append(result.score)  # Qdrant returns cosine similarity directly
+
+            # Pad results if fewer than K matches
+            while len(indices) < K:
+                indices.append(-1)
+                distances.append(0.0)
+
+            all_distances.append(distances)
+            all_indices.append(indices)
+
+        return RMOutput(
+            distances=np.array(all_distances, dtype=np.float32),
+            indices=np.array(all_indices, dtype=np.int64)
+        )
+
+    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
+        """Retrieve vectors for specific document IDs"""
+        if self.collection_name != collection_name:
+            self.load_index(collection_name)
+
+        # Fetch points from Qdrant
+        points = self.client.retrieve(
+            collection_name=collection_name,
+            ids=ids,
+            with_vectors=True,
+            with_payload=False
+        )
+
+        # Extract and return vectors
+        vectors = []
+        for point in points:
+            if point.vector is not None:
+                vectors.append(point.vector)
+            else:
+                raise ValueError(f"Vector not found for id {point.id}")
+
+        return np.array(vectors, dtype=np.float64)
\ No newline at end of file
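
QdrantVS mirrors the ChromaVS flow: the collection is created on the first index() call and Qdrant returns cosine similarity scores directly. A matching usage sketch, again not part of the patch, reusing the placeholder embed() and docs from the ChromaVS example and assuming qdrant-client's in-memory mode:

    from qdrant_client import QdrantClient

    from lotus.vector_store.qdrant_vs import QdrantVS

    vs = QdrantVS(client=QdrantClient(":memory:"), embedding_model=embed)
    vs.index(docs, collection_name="animals")

    out = vs("which pet makes noise?", K=2)   # RMOutput; scores are cosine similarities
    print(out.indices, out.distances)

    # Stored vectors can be fetched back by integer point id
    vectors = vs.get_vectors_from_index("animals", ids=[0, 1])
    print(vectors.shape)                      # (2, 384) with the placeholder embedder
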
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index 89a7bce..0f37a1a 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -18,7 +18,8 @@ class VS(ABC):
     def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None:
         self.collection_name: str | None = None
         self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model
-        pass
+        self.max_batch_size: int = 64
+
 
     @abstractmethod
     def index(self, docs: pd.Series, collection_name: str):
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 364bca3..e1e957b 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -17,7 +17,7 @@
     raise ImportError("Please install the weaviate client") from err
 
 class WeaviateVS(VS):
-    def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
        """Initialize with Weaviate client and embedding model"""
         super().__init__(embedding_model)
         self.client = weaviate_client
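
For existing callers, the user-visible changes in this patch are confined to constructor signatures: PineconeVS no longer takes an environment argument, WeaviateVS is annotated to accept only weaviate.WeaviateClient (the v4 client), and the VS base class now carries a max_batch_size default of 64. A rough sketch of updated call sites, assuming placeholder credentials, the embed() stand-in from the earlier examples, and a locally running Weaviate instance:

    import weaviate

    from lotus.vector_store.pinecone_vs import PineconeVS
    from lotus.vector_store.weaviate_vs import WeaviateVS

    # Previously: PineconeVS(api_key=..., environment=..., embedding_model=embed)
    pinecone_vs = PineconeVS(api_key="YOUR_PINECONE_API_KEY", embedding_model=embed)

    # Pass a v4 WeaviateClient; the v3 weaviate.Client is no longer part of the annotation
    weaviate_vs = WeaviateVS(weaviate.connect_to_local(), embedding_model=embed)
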