From 99cb535ad92f0bcab42f818b8366e973dd4c8ed7 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 22:15:43 -0800 Subject: [PATCH] made type changes to weaviate file --- lotus/vector_store/vs.py | 2 +- lotus/vector_store/weaviate_vs.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 0f37a1a..c370df6 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -41,7 +41,7 @@ def __call__(self, pass @abstractmethod - def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]: pass def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 82fff0b..26585cc 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Union +from typing import Any, Callable, List, Union import numpy as np import pandas as pd @@ -103,6 +103,7 @@ def __call__(self, limit=K, return_metadata=MetadataQuery(distance=True) )) + response.objects[0].metadata.distance results.append(response) # Process results into expected format @@ -112,12 +113,12 @@ def __call__(self, for result in results: objects = result.objects - distances = [] + distances:List[float] = [] indices = [] for obj in objects: indices.append(obj.properties.get('content')) # Convert cosine distance to similarity score - distance = obj.metadata.distance + distance:float = obj.metadata.distance distances.append(1 - distance) # Convert distance to similarity # Pad results if fewer than K matches @@ -133,21 +134,19 @@ def __call__(self, indices=np.array(all_indices, dtype=np.int64).tolist() ) - def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]: + def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]: """Retrieve vectors for specific document IDs""" collection = self.client.collections.get(collection_name) # Query for documents with specific doc_ids vectors = [] - response = collection.query.fetch_objects_by_ids(ids=ids) for id in ids: response = collection.query.fetch_object_by_id(uuid=id) if response: vectors.append(response.vector) else: raise ValueError(f'{id} does not exist in {collection_name}') - return np.array(vectors, dtype=np.float64)