Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding implementations for vector stores #79

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e3abd90
initial scaffolding for adding vector store / vector database integra…
AmoghTantradi Jan 12, 2025
bd1e8fd
fixed linting, ruff checks pass
AmoghTantradi Jan 12, 2025
880c31f
added changes to requirements.txt file and added additional abstract …
AmoghTantradi Jan 12, 2025
7b5dfd3
refactored
AmoghTantradi Jan 12, 2025
08dfaba
added tests for clustering and filtering
AmoghTantradi Jan 13, 2025
f3a82c1
made edits to test_filter
AmoghTantradi Jan 13, 2025
fc62846
added implementations for weaviate and pinecone vs
AmoghTantradi Jan 14, 2025
3e89b5f
fixed merge conflicts
AmoghTantradi Jan 14, 2025
f2937ad
added extra refactoring and added implementations for qdrant and chro…
AmoghTantradi Jan 14, 2025
a4c7418
fixed some type errors
AmoghTantradi Jan 14, 2025
1357fb3
made further corrections
AmoghTantradi Jan 15, 2025
c76b658
edit uuid type
AmoghTantradi Jan 15, 2025
9f257f7
changed uuid type
AmoghTantradi Jan 15, 2025
99cb535
made type changes to weaviate file
AmoghTantradi Jan 15, 2025
3c8a742
made another change
AmoghTantradi Jan 15, 2025
ccd9e48
typecheck passes for weaviate?
AmoghTantradi Jan 15, 2025
89bf974
type changes for weaviate and qdrant files
AmoghTantradi Jan 16, 2025
a76adb7
made changes to weaviate file
AmoghTantradi Jan 16, 2025
c3e0f0c
made changes to weaviate file
AmoghTantradi Jan 16, 2025
1782281
fixed pinecone type errors
AmoghTantradi Jan 16, 2025
0621b9b
fixed pinecone type errors
AmoghTantradi Jan 16, 2025
b568d1e
type checks all pass locally
AmoghTantradi Jan 16, 2025
9b33a1f
fixed linting errors
AmoghTantradi Jan 16, 2025
820f3be
made refactors to allow for testing
AmoghTantradi Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/tests/rm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ def setup_models():

for model_name in ENABLED_MODEL_NAMES:
models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name)


return models


@pytest.fixture(scope='session')
def setup_vs():
pass

################################################################################
# RM Only Tests
################################################################################
Expand Down
153 changes: 145 additions & 8 deletions lotus/vector_store/chroma_vs.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,153 @@
from typing import Any, Mapping, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from PIL import Image
from tqdm import tqdm

from lotus.types import RMOutput
from lotus.vector_store.vs import VS

try:
from chromadb import ClientAPI
from chromadb.api import Collection
from chromadb.api.types import IncludeEnum
except ImportError as err:
raise ImportError(
"The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"
) from err

class ChromaVS(VS):
def __init__(self):
def __init__(self, client: ClientAPI, embedding_model: str, max_batch_size: int = 64):
"""Initialize with ChromaDB client and embedding model"""
super().__init__(embedding_model)
self.client = client
self.collection: Collection | None = None
self.collection_name = None
self.max_batch_size = max_batch_size


def index(self, docs: pd.Series, collection_name: str):
"""Create a collection and add documents with their embeddings"""
self.collection_name = collection_name

# Create collection without embedding function (we'll provide embeddings directly)
self.collection = self.client.create_collection(
name=collection_name,
metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency
)

# Convert docs to list if it's a pandas Series
docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

# Generate embeddings
embeddings = self._batch_embed(docs_list)

# Prepare documents for addition
ids = [str(i) for i in range(len(docs_list))]
metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))]

# Add documents in batches
batch_size = 100
for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"):
end_idx = min(i + batch_size, len(docs_list))
self.collection.add(
ids=ids[i:end_idx],
documents=docs_list[i:end_idx],
embeddings=embeddings[i:end_idx].tolist(),
metadatas=metadatas[i:end_idx]
)

def load_index(self, collection_name: str):
"""Load an existing collection"""
try:
import chromadb
except ImportError:
chromadb = None
self.collection = self.client.get_collection(collection_name)
self.collection_name = collection_name
except ValueError as e:
raise ValueError(f"Collection {collection_name} not found") from e

def __call__(
self,
queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
K: int,
**kwargs: dict[str, Any]
) -> RMOutput:
"""Perform vector search using ChromaDB"""
if self.collection is None:
raise ValueError("No collection loaded. Call load_index first.")

# Convert single query to list
if isinstance(queries, (str, Image.Image)):
queries = [queries]

# Handle numpy array queries (pre-computed vectors)
if isinstance(queries, np.ndarray):
query_vectors = queries
else:
# Convert queries to list if needed
if isinstance(queries, pd.Series):
queries = queries.tolist()
# Create embeddings for text queries
query_vectors = self._batch_embed(queries)

if chromadb is None:
raise ImportError(
"The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`",
# Perform searches
all_distances = []
all_indices = []

for query_vector in query_vectors:
results = self.collection.query(
query_embeddings=[query_vector.tolist()],
n_results=K,
include=[IncludeEnum.metadatas, IncludeEnum.distances]
)
pass

# Extract distances and indices
distances = []
indices = []

if results['metadatas'] and results['distances']:
for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
indices.append(metadata['doc_id'])
# ChromaDB returns squared L2 distances, convert to cosine similarity
# similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity
distances.append(1 - (distance / 2))

# Pad results if fewer than K matches
while len(indices) < K:
indices.append(-1)
distances.append(0.0)

all_distances.append(distances)
all_indices.append(indices)

return RMOutput(
distances=np.array(all_distances, dtype=np.float32).tolist(),
indices=np.array(all_indices, dtype=np.int64).tolist()
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.collection is None or self.collection_name != collection_name:
self.load_index(collection_name)


if self.collection is None: # Add this check after load_index
raise ValueError(f"Failed to load collection {collection_name}")


# Convert integer ids to strings for ChromaDB
str_ids = [str(id) for id in ids]

# Get embeddings from ChromaDB
results = self.collection.get(
ids=str_ids,
include=[IncludeEnum.embeddings]
)

if not results['embeddings']:
raise ValueError("No vectors found for the given ids")

return np.array(results['embeddings'], dtype=np.float64)


159 changes: 150 additions & 9 deletions lotus/vector_store/pinecone_vs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,158 @@
from typing import Any, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from PIL import Image
from tqdm import tqdm

from lotus.types import RMOutput
from lotus.vector_store.vs import VS

try:
from pinecone import Index, Pinecone
except ImportError as err:
raise ImportError(
"The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
) from err

class PineconeVS(VS):
def __init__(self):
try:
import pinecone
except ImportError:
pinecone = None
def __init__(self, api_key: str, embedding_model: str, max_batch_size: int = 64):
"""Initialize Pinecone client with API key and environment"""
super().__init__(embedding_model)
self.pinecone = Pinecone(api_key=api_key)
self.pc_index:Index | None = None
self.max_batch_size = max_batch_size


def index(self, docs: pd.Series, collection_name: str):
"""Create an index and add documents to it"""
self.collection_name = collection_name

# Get sample embedding to determine vector dimension
sample_embedding = self._embed([docs.iloc[0]])
dimension = sample_embedding.shape[1]

# Check if index already exists
if collection_name not in self.pinecone.list_indexes():
# Create new index with the correct dimension
self.pinecone.create_index(
name=collection_name,
dimension=dimension,
metric="cosine"
)

# Connect to index
self.pc_index = self.pinecone.Index(collection_name)

# Convert docs to list if it's a pandas Series
docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs

# Create embeddings using the provided embedding model
embeddings = self._batch_embed(docs_list)

# Prepare vectors for upsert
vectors = []
for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)):
vectors.append({
"id": str(idx),
"values": embedding.tolist(), # Pinecone expects lists, not numpy arrays
"metadata": {
"content": doc,
"doc_id": idx
}
})

# Upsert in batches of 100
batch_size = 100
for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading to Pinecone"):
batch = vectors[i:i + batch_size]
self.pc_index.upsert(vectors=batch)

def load_index(self, collection_name: str):
"""Connect to an existing Pinecone index"""
if collection_name not in self.pinecone.list_indexes():
raise ValueError(f"Index {collection_name} not found")

self.collection_name = collection_name
self.pc_index = self.pinecone.Index(collection_name)

def __call__(
self,
queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]],
K: int,
**kwargs: dict[str, Any]
) -> RMOutput:
"""Perform vector search using Pinecone"""
if self.pc_index is None:
raise ValueError("No index loaded. Call load_index first.")

# Convert single query to list
if isinstance(queries, (str, Image.Image)):
queries = [queries]

# Handle numpy array queries (pre-computed vectors)
if isinstance(queries, np.ndarray):
query_vectors = queries
else:
# Convert queries to list if needed
if isinstance(queries, pd.Series):
queries = queries.tolist()
# Create embeddings for text queries
query_vectors = self._batch_embed(queries)

# Perform searches
all_distances = []
all_indices = []

if pinecone is None:
raise ImportError(
"The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
for query_vector in query_vectors:
# Query Pinecone
results = self.pc_index.query(
vector=query_vector.tolist(),
top_k=K,
include_metadata=True,
**kwargs
)

pass
# Extract distances and indices
distances = []
indices = []

for match in results.matches:
indices.append(int(match.metadata["doc_id"]))
distances.append(match.score)

# Pad results if fewer than K matches
while len(indices) < K:
indices.append(-1) # Use -1 for padding
distances.append(0.0)

all_distances.append(distances)
all_indices.append(indices)

return RMOutput(
distances=np.array(all_distances, dtype=np.float32).tolist(),
indices=np.array(all_indices, dtype=np.int64).tolist()
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.pc_index is None or self.collection_name != collection_name:
self.load_index(collection_name)

if self.pc_index is None: # Add this check after load_index
raise ValueError("Failed to initialize Pinecone index")



# Fetch vectors from Pinecone
vectors = []
for doc_id in ids:
response = self.pc_index.fetch(ids=[str(doc_id)])
if str(doc_id) in response.vectors:
vector = response.vectors[str(doc_id)].values
vectors.append(vector)
else:
raise ValueError(f"Document with id {doc_id} not found")

return np.array(vectors, dtype=np.float64)
Loading
Loading