diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index d05b1e4..3735f79 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, Mapping, Union import numpy as np import pandas as pd @@ -11,14 +11,16 @@ try: import chromadb + from chromadb import ClientAPI from chromadb.api import Collection + from chromadb.api.types import IncludeEnum except ImportError as err: raise ImportError( "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`" ) from err class ChromaVS(VS): - def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize with ChromaDB client and embedding model""" super().__init__(embedding_model) self.client = client @@ -45,7 +47,7 @@ def index(self, docs: pd.Series, collection_name: str): # Prepare documents for addition ids = [str(i) for i in range(len(docs_list))] - metadatas = [{"doc_id": i} for i in range(len(docs_list))] + metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))] # Add documents in batches batch_size = 100 @@ -98,14 +100,14 @@ def __call__( results = self.collection.query( query_embeddings=[query_vector.tolist()], n_results=K, - include=['metadatas', 'distances'] + include=[IncludeEnum.metadatas, IncludeEnum.distances] ) # Extract distances and indices distances = [] indices = [] - if results['metadatas']: + if results['metadatas'] and results['distances']: for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): indices.append(metadata['doc_id']) # ChromaDB returns squared L2 distances, convert to cosine similarity @@ -121,8 +123,8 @@ def __call__( all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: @@ -130,13 +132,18 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArra if self.collection is None or self.collection_name != collection_name: self.load_index(collection_name) + + if self.collection is None: # Add this check after load_index + raise ValueError(f"Failed to load collection {collection_name}") + + # Convert integer ids to strings for ChromaDB str_ids = [str(id) for id in ids] # Get embeddings from ChromaDB results = self.collection.get( ids=str_ids, - include=['embeddings'] + include=[IncludeEnum.embeddings] ) if not results['embeddings']: