Skip to content

Commit

Permalink
added extra refactoring and added implementations for qdrant and chroma_vs
Browse files Browse the repository at this point in the history
  • Loading branch information
AmoghTantradi committed Jan 14, 2025
1 parent 3e89b5f commit f2937ad
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 20 deletions.
147 changes: 139 additions & 8 deletions lotus/vector_store/chroma_vs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,147 @@
from typing import Any, Callable, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from PIL import Image
from tqdm import tqdm

from lotus.types import RMOutput
from lotus.vector_store.vs import VS

try:
import chromadb
from chromadb.api import Collection
except ImportError as err:
raise ImportError(
"The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"
) from err

class ChromaVS(VS):
    """Vector store backed by a ChromaDB collection."""

    def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
        """Initialize with ChromaDB client and embedding model.

        Args:
            client: A pre-configured chromadb client instance.
            embedding_model: Callable mapping a batch of documents to a
                2-D float embedding array.
            max_batch_size: Maximum number of items embedded/uploaded per batch.
        """
        # NOTE: a leftover `def __init__(self):` line from the old version was
        # removed here; it shadowed this constructor and was invalid syntax.
        super().__init__(embedding_model)
        self.client = client
        self.collection: Collection | None = None  # set by index()/load_index()
        self.collection_name: str | None = None
        self.max_batch_size = max_batch_size

def index(self, docs: pd.Series, collection_name: str):
"""Create a collection and add documents with their embeddings"""
self.collection_name = collection_name

# Create collection without embedding function (we'll provide embeddings directly)
self.collection = self.client.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency
)

# Convert docs to list if it's a pandas Series
docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

# Generate embeddings
embeddings = self._batch_embed(docs_list)

# Prepare documents for addition
ids = [str(i) for i in range(len(docs_list))]
metadatas = [{"doc_id": i} for i in range(len(docs_list))]

# Add documents in batches
batch_size = 100
for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"):
end_idx = min(i + batch_size, len(docs_list))
self.collection.add(
ids=ids[i:end_idx],
documents=docs_list[i:end_idx],
embeddings=embeddings[i:end_idx].tolist(),
metadatas=metadatas[i:end_idx]
)

def load_index(self, collection_name: str):
"""Load an existing collection"""
try:
import chromadb
except ImportError:
chromadb = None
self.collection = self.client.get_collection(collection_name)
self.collection_name = collection_name
except ValueError as e:
raise ValueError(f"Collection {collection_name} not found") from e

def __call__(
self,
queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
K: int,
**kwargs: dict[str, Any]
) -> RMOutput:
"""Perform vector search using ChromaDB"""
if self.collection is None:
raise ValueError("No collection loaded. Call load_index first.")

# Convert single query to list
if isinstance(queries, (str, Image.Image)):
queries = [queries]

# Handle numpy array queries (pre-computed vectors)
if isinstance(queries, np.ndarray):
query_vectors = queries
else:
# Convert queries to list if needed
if isinstance(queries, pd.Series):
queries = queries.tolist()
# Create embeddings for text queries
query_vectors = self._batch_embed(queries)

# Perform searches
all_distances = []
all_indices = []

if chromadb is None:
raise ImportError(
"The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`",
for query_vector in query_vectors:
results = self.collection.query(
query_embeddings=[query_vector.tolist()],
n_results=K,
include=['metadatas', 'distances']
)
pass

# Extract distances and indices
distances = []
indices = []

if results['metadatas']:
for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
indices.append(metadata['doc_id'])
# ChromaDB returns squared L2 distances, convert to cosine similarity
# similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity
distances.append(1 - (distance / 2))

# Pad results if fewer than K matches
while len(indices) < K:
indices.append(-1)
distances.append(0.0)

all_distances.append(distances)
all_indices.append(indices)

return RMOutput(
distances=np.array(all_distances, dtype=np.float32),
indices=np.array(all_indices, dtype=np.int64)
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.collection is None or self.collection_name != collection_name:
self.load_index(collection_name)

# Convert integer ids to strings for ChromaDB
str_ids = [str(id) for id in ids]

# Get embeddings from ChromaDB
results = self.collection.get(
ids=str_ids,
include=['embeddings']
)

if not results['embeddings']:
raise ValueError("No vectors found for the given ids")

return np.array(results['embeddings'], dtype=np.float64)


2 changes: 1 addition & 1 deletion lotus/vector_store/pinecone_vs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
) from err

class PineconeVS(VS):
def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
"""Initialize Pinecone client with API key and environment"""
super().__init__(embedding_model)
self.pinecone = Pinecone(api_key=api_key)
Expand Down
155 changes: 146 additions & 9 deletions lotus/vector_store/qdrant_vs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,153 @@
from typing import Any, Callable, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from PIL import Image
from tqdm import tqdm

from lotus.types import RMOutput
from lotus.vector_store.vs import VS

try:
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
except ImportError as err:
raise ImportError("Please install the qdrant client") from err

class QdrantVS(VS):
    """Vector store backed by a Qdrant collection."""

    def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
        """Initialize with Qdrant client and embedding model.

        Args:
            client: A pre-configured QdrantClient instance.
            embedding_model: Callable mapping a batch of documents to a
                2-D float embedding array.
            max_batch_size: Maximum number of items embedded/uploaded per batch.
        """
        # NOTE: removed leftover diff residue (an old `def __init__(self):`
        # plus an orphaned qdrant_client import guard) that preceded this
        # constructor and was invalid syntax.
        super().__init__(embedding_model)
        self.client = client
        self.max_batch_size = max_batch_size

def index(self, docs: pd.Series, collection_name: str):
"""Create a collection and add documents with their embeddings"""
self.collection_name = collection_name

# Get sample embedding to determine vector dimension
sample_embedding = self._embed([docs.iloc[0]])
dimension = sample_embedding.shape[1]

# Create collection if it doesn't exist
if not self.client.collection_exists(collection_name):
self.client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
)

# Convert docs to list if it's a pandas Series
docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

# Generate embeddings
embeddings = self._batch_embed(docs_list)

# Prepare points for upload
points = []
for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)):
points.append(
PointStruct(
id=idx,
vector=embedding.tolist(),
payload={
"content": doc,
"doc_id": idx
}
)
)

# Upload in batches
batch_size = 100
for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"):
batch = points[i:i + batch_size]
self.client.upsert(
collection_name=collection_name,
points=batch
)

def load_index(self, collection_name: str):
"""Set the collection name to use"""
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection {collection_name} not found")
self.collection_name = collection_name

def __call__(
self,
queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
K: int,
**kwargs: dict[str, Any]
) -> RMOutput:
"""Perform vector search using Qdrant"""
if self.collection_name is None:
raise ValueError("No collection loaded. Call load_index first.")

if qdrant_client is None:
raise ImportError(
"The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`",
# Convert single query to list
if isinstance(queries, (str, Image.Image)):
queries = [queries]

# Handle numpy array queries (pre-computed vectors)
if isinstance(queries, np.ndarray):
query_vectors = queries
else:
# Convert queries to list if needed
if isinstance(queries, pd.Series):
queries = queries.tolist()
# Create embeddings for text queries
query_vectors = self._batch_embed(queries)

# Perform searches
all_distances = []
all_indices = []

for query_vector in query_vectors:
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector.tolist(),
limit=K,
with_payload=True
)
pass

# Extract distances and indices
distances = []
indices = []

for result in results:
indices.append(result.payload["doc_id"])
distances.append(result.score) # Qdrant returns cosine similarity directly

# Pad results if fewer than K matches
while len(indices) < K:
indices.append(-1)
distances.append(0.0)

all_distances.append(distances)
all_indices.append(indices)

return RMOutput(
distances=np.array(all_distances, dtype=np.float32),
indices=np.array(all_indices, dtype=np.int64)
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.collection_name != collection_name:
self.load_index(collection_name)

# Fetch points from Qdrant
points = self.client.retrieve(
collection_name=collection_name,
ids=ids,
with_vectors=True,
with_payload=False
)

# Extract and return vectors
vectors = []
for point in points:
if point.vector is not None:
vectors.append(point.vector)
else:
raise ValueError(f"Vector not found for id {point.id}")

return np.array(vectors, dtype=np.float64)
3 changes: 2 additions & 1 deletion lotus/vector_store/vs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class VS(ABC):
def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None:
self.collection_name: str | None = None
self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model
pass
self.max_batch_size:int = 64


@abstractmethod
def index(self, docs: pd.Series, collection_name: str):
Expand Down
2 changes: 1 addition & 1 deletion lotus/vector_store/weaviate_vs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
raise ImportError("Please install the weaviate client") from err

class WeaviateVS(VS):
def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
"""Initialize with Weaviate client and embedding model"""
super().__init__(embedding_model)
self.client = weaviate_client
Expand Down

0 comments on commit f2937ad

Please sign in to comment.