Skip to content

Commit

Permalink
made type changes to weaviate file
Browse files Browse the repository at this point in the history
  • Loading branch information
AmoghTantradi committed Jan 15, 2025
1 parent 9f257f7 commit 99cb535
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lotus/vector_store/vs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __call__(self,
pass

@abstractmethod
def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]:
def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]:
pass

def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
Expand Down
11 changes: 5 additions & 6 deletions lotus/vector_store/weaviate_vs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Union
from typing import Any, Callable, List, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -103,6 +103,7 @@ def __call__(self,
limit=K,
return_metadata=MetadataQuery(distance=True)
))
response.objects[0].metadata.distance
results.append(response)

# Process results into expected format
Expand All @@ -112,12 +113,12 @@ def __call__(self,
for result in results:
objects = result.objects

distances = []
distances:List[float] = []
indices = []
for obj in objects:
indices.append(obj.properties.get('content'))
# Convert cosine distance to similarity score
distance = obj.metadata.distance
distance:float = obj.metadata.distance
distances.append(1 - distance) # Convert distance to similarity

# Pad results if fewer than K matches
Expand All @@ -133,21 +134,19 @@ def __call__(self,
indices=np.array(all_indices, dtype=np.int64).tolist()
)

def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]:
def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]:
"""Retrieve vectors for specific document IDs"""
collection = self.client.collections.get(collection_name)

# Query for documents with specific doc_ids
vectors = []

response = collection.query.fetch_objects_by_ids(ids=ids)
for id in ids:
response = collection.query.fetch_object_by_id(uuid=id)
if response:
vectors.append(response.vector)
else:
raise ValueError(f'{id} does not exist in {collection_name}')

return np.array(vectors, dtype=np.float64)


0 comments on commit 99cb535

Please sign in to comment.