Skip to content

Commit

Permalink
type checks all pass locally
Browse files (browse the repository at this point in the history)
  • Loading branch information
AmoghTantradi committed Jan 16, 2025
1 parent 0621b9b commit b568d1e
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions lotus/vector_store/chroma_vs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Union
from typing import Any, Callable, Mapping, Union

import numpy as np
import pandas as pd
Expand All @@ -11,14 +11,16 @@

try:
import chromadb
from chromadb import ClientAPI
from chromadb.api import Collection
from chromadb.api.types import IncludeEnum
except ImportError as err:
raise ImportError(
"The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"
) from err

class ChromaVS(VS):
def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
"""Initialize with ChromaDB client and embedding model"""
super().__init__(embedding_model)
self.client = client
Expand All @@ -45,7 +47,7 @@ def index(self, docs: pd.Series, collection_name: str):

# Prepare documents for addition
ids = [str(i) for i in range(len(docs_list))]
metadatas = [{"doc_id": i} for i in range(len(docs_list))]
metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))]

# Add documents in batches
batch_size = 100
Expand Down Expand Up @@ -98,14 +100,14 @@ def __call__(
results = self.collection.query(
query_embeddings=[query_vector.tolist()],
n_results=K,
include=['metadatas', 'distances']
include=[IncludeEnum.metadatas, IncludeEnum.distances]
)

# Extract distances and indices
distances = []
indices = []

if results['metadatas']:
if results['metadatas'] and results['distances']:
for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
indices.append(metadata['doc_id'])
# ChromaDB returns squared L2 distances; convert them to cosine similarity
Expand All @@ -121,22 +123,27 @@ def __call__(
all_indices.append(indices)

return RMOutput(
distances=np.array(all_distances, dtype=np.float32),
indices=np.array(all_indices, dtype=np.int64)
distances=np.array(all_distances, dtype=np.float32).tolist(),
indices=np.array(all_indices, dtype=np.int64).tolist()
)

def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
if self.collection is None or self.collection_name != collection_name:
self.load_index(collection_name)


if self.collection is None: # Guard: the collection is still unset after load_index
raise ValueError(f"Failed to load collection {collection_name}")


# Convert integer ids to strings for ChromaDB
str_ids = [str(id) for id in ids]

# Get embeddings from ChromaDB
results = self.collection.get(
ids=str_ids,
include=['embeddings']
include=[IncludeEnum.embeddings]
)

if not results['embeddings']:
Expand Down

0 comments on commit b568d1e

Please sign in to comment.