From 89bf9743ec5358241ac96595d05a3ffbba075739 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Wed, 15 Jan 2025 17:09:59 -0800 Subject: [PATCH] type changes for weaviate and qdrant files --- lotus/vector_store/qdrant_vs.py | 6 +++--- lotus/vector_store/weaviate_vs.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index c672727..28a7c7d 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -113,7 +113,7 @@ def __call__( indices = [] for result in results: - indices.append(result.payload["doc_id"]) + indices.append(result.id) distances.append(result.score) # Qdrant returns cosine similarity directly # Pad results if fewer than K matches @@ -125,8 +125,8 @@ def __call__( all_indices.append(indices) return RMOutput( - distances=np.array(all_distances, dtype=np.float32), - indices=np.array(all_indices, dtype=np.int64) + distances=np.array(all_distances, dtype=np.float32).tolist(), + indices=np.array(all_indices, dtype=np.int64).tolist() ) def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 44eb746..9f18ce0 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -103,7 +103,7 @@ def __call__(self, limit=K, return_metadata=MetadataQuery(distance=True) )) - response.objects[0].metadata.distance + response.objects[0].uuid results.append(response) # Process results into expected format @@ -116,7 +116,7 @@ def __call__(self, distances:List[float] = [] indices = [] for obj in objects: - indices.append(obj.properties.get('content')) + indices.append(obj.uuid) # Convert cosine distance to similarity score distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0 distances.append(1 - distance) # Convert distance to similarity