From e3abd9027618e1c267165eb4f8ce6f8a70ddac18 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 11:12:21 -0800 Subject: [PATCH 01/23] initial scaffolding for adding vector store / vector database integration --- lotus/__init__.py | 2 ++ lotus/settings.py | 2 ++ lotus/vector_store/__init__.py | 13 +++++++++++++ lotus/vector_store/chroma_vs.py | 20 ++++++++++++++++++++ lotus/vector_store/pinecone_vs.py | 18 ++++++++++++++++++ lotus/vector_store/qdrant_vs.py | 21 +++++++++++++++++++++ lotus/vector_store/vs.py | 13 +++++++++++++ lotus/vector_store/weaviate_vs.py | 15 +++++++++++++++ 8 files changed, 104 insertions(+) create mode 100644 lotus/vector_store/__init__.py create mode 100644 lotus/vector_store/chroma_vs.py create mode 100644 lotus/vector_store/pinecone_vs.py create mode 100644 lotus/vector_store/qdrant_vs.py create mode 100644 lotus/vector_store/vs.py create mode 100644 lotus/vector_store/weaviate_vs.py diff --git a/lotus/__init__.py b/lotus/__init__.py index f66cfb5c..d20f710d 100644 --- a/lotus/__init__.py +++ b/lotus/__init__.py @@ -1,6 +1,7 @@ import logging import lotus.dtype_extensions import lotus.models +import lotus.vector_store import lotus.nl_expression import lotus.templates import lotus.utils @@ -44,6 +45,7 @@ "templates", "logger", "models", + "vector_store", "utils", "dtype_extensions", ] diff --git a/lotus/settings.py b/lotus/settings.py index 99e59449..a1c54c56 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -1,4 +1,5 @@ import lotus.models +import lotus.vector_store from lotus.types import SerializationFormat # NOTE: Settings class is not thread-safe @@ -10,6 +11,7 @@ class Settings: rm: lotus.models.RM | None = None helper_lm: lotus.models.LM | None = None reranker: lotus.models.Reranker | None = None + vs: lotus.vector_store.VS | None = None # Cache settings enable_message_cache: bool = False diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py new file mode 100644 index 00000000..0f634b9a --- /dev/null +++ b/lotus/vector_store/__init__.py @@ -0,0 +1,13 @@ +from lotus.vector_store.vs import VS +from lotus.vector_store.weaviate_vs import WeaviateVS +from lotus.vector_store.pinecone_vs import PineconeVS +from lotus.vector_store.chroma_vs import ChromaVS +from lotus.vector_store.qdrant_vs import QdrantVS + +__all__ = [ + "VS", + "WeaviateVS", + "PineconeVS", + "ChromaVS", + "QdrantVS" +] \ No newline at end of file diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py new file mode 100644 index 00000000..4a70720e --- /dev/null +++ b/lotus/vector_store/chroma_vs.py @@ -0,0 +1,20 @@ +from lotus.vector_store.vs import VS + + +try: + import chromadb +except ImportError: + chromadb = None + + +if chromadb is None: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", + ) + + + +class ChromaVS(VS): + + def __init__(self): + pass \ No newline at end of file diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py new file mode 100644 index 00000000..ae1f116b --- /dev/null +++ b/lotus/vector_store/pinecone_vs.py @@ -0,0 +1,18 @@ +from lotus.vector_store.vs import VS + +try: + import pinecone +except ImportError: + pinecone = None + + +if pinecone is None: + raise ImportError( + "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", + ) + +class PineconeVS(VS): + + def __init__(self): + pass + diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py new file mode 100644 index 00000000..fa628e56 --- /dev/null +++ b/lotus/vector_store/qdrant_vs.py @@ -0,0 +1,21 @@ +from lotus.vector_store.vs import VS + + +try: + import qdrant_client +except ImportError: + qdrant_client = None + + +if qdrant_client is None: + raise ImportError( + "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", + ) + + + + +class QdrantVS(VS): + + def __init__(self): + pass \ No newline at end of file diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py new file mode 100644 index 00000000..e13f64e5 --- /dev/null +++ b/lotus/vector_store/vs.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + +import pandas as pd + +class VS(ABC): + """Abstract class for vector stores.""" + + def __init__(self) -> None: + pass + + @abstractmethod + def index(self, docs: pd.Series, index_dir): + pass \ No newline at end of file diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py new file mode 100644 index 00000000..4f18f35a --- /dev/null +++ b/lotus/vector_store/weaviate_vs.py @@ -0,0 +1,15 @@ +from lotus.vector_store.vs import VS + +try: + import weaviate +except ImportError as err: + raise ImportError( + "Please install the weaviate client" + ) + +class WeaviateVS(VS): + + def __init__(self): + pass + + \ No newline at end of file From bd1e8fddf01f6931d25a88cf00018851b86a80bd Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 13:29:07 -0800 Subject: [PATCH 02/23] fixed linting, ruff checks pass --- lotus/vector_store/__init__.py | 10 ++-------- lotus/vector_store/chroma_vs.py | 7 ++----- lotus/vector_store/pinecone_vs.py | 7 +++---- lotus/vector_store/qdrant_vs.py | 10 +++------- lotus/vector_store/vs.py | 11 ++++++----- lotus/vector_store/weaviate_vs.py | 16 +++++++--------- 6 files changed, 23 insertions(+), 38 deletions(-) diff --git a/lotus/vector_store/__init__.py b/lotus/vector_store/__init__.py index 0f634b9a..34c41998 100644 --- a/lotus/vector_store/__init__.py +++ b/lotus/vector_store/__init__.py @@ -1,13 +1,7 @@ -from lotus.vector_store.vs import VS +from lotus.vector_store.vs import VS from lotus.vector_store.weaviate_vs import WeaviateVS from lotus.vector_store.pinecone_vs import PineconeVS from lotus.vector_store.chroma_vs import ChromaVS from lotus.vector_store.qdrant_vs import QdrantVS -__all__ = [ - "VS", - "WeaviateVS", - "PineconeVS", - "ChromaVS", - "QdrantVS" -] \ No newline at end of file +__all__ = ["VS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"] diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 4a70720e..e94544e0 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,10 +1,9 @@ from lotus.vector_store.vs import VS - try: import chromadb except ImportError: - chromadb = None + chromadb = None if chromadb is None: @@ -13,8 +12,6 @@ ) - class ChromaVS(VS): - def __init__(self): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index ae1f116b..19fc0967 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -3,7 +3,7 @@ try: import pinecone except ImportError: - pinecone = None + pinecone = None if pinecone is None: @@ -11,8 +11,7 @@ "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", ) -class PineconeVS(VS): +class PineconeVS(VS): def __init__(self): - pass - + pass diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index fa628e56..6fc82dae 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,10 +1,9 @@ -from lotus.vector_store.vs import VS - +from lotus.vector_store.vs import VS try: import qdrant_client except ImportError: - qdrant_client = None + qdrant_client = None if qdrant_client is None: @@ -13,9 +12,6 @@ ) - - class QdrantVS(VS): - def __init__(self): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index e13f64e5..8bfc43e8 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,13 +1,14 @@ -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod + +import pandas as pd -import pandas as pd class VS(ABC): """Abstract class for vector stores.""" def __init__(self) -> None: - pass + pass - @abstractmethod + @abstractmethod def index(self, docs: pd.Series, index_dir): - pass \ No newline at end of file + pass diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 4f18f35a..0ee5ded2 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,15 +1,13 @@ +from typing import Optional, Union + from lotus.vector_store.vs import VS try: import weaviate -except ImportError as err: - raise ImportError( - "Please install the weaviate client" - ) - -class WeaviateVS(VS): +except ImportError: + raise ImportError("Please install the weaviate client") - def __init__(self): - pass - \ No newline at end of file +class WeaviateVS(VS): + def __init__(self, weaviate_collection_name:str, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], weaviate_collection_text_key: Optional[str] = "content"): + pass From 880c31f110ec14b72ed35496643aa3f58abcc9fb Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 13:58:02 -0800 Subject: [PATCH 03/23] added changes to requirements.txt file and added additional abstract methods --- lotus/vector_store/vs.py | 18 ++++++++++++++++++ requirements.txt | 6 +++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index 8bfc43e8..bb4878eb 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,6 +1,12 @@ from abc import ABC, abstractmethod +from typing import Any +import numpy as np import pandas as pd +from numpy.typing import NDArray +from PIL import Image + +from lotus.types import RMOutput class VS(ABC): @@ -12,3 +18,15 @@ def __init__(self) -> None: @abstractmethod def index(self, docs: pd.Series, index_dir): pass + + @abstractmethod + def search(self, + queries: pd.Series | str | Image.Image | list | NDArray[np.float64], + K:int, + **kwargs: dict[str, Any], + ) -> RMOutput: + pass + + @abstractmethod + def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: + pass \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 226370bc..e645c716 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ numpy==1.26.4 pandas==2.2.2 sentence-transformers==3.0.1 tiktoken==0.7.0 -tqdm==4.66.4 \ No newline at end of file +tqdm==4.66.4 +weaviate-client==4.10.2 +pinecone==5.4.2 +chromadb==0.6.2 +qdrant-client==1.12.2 \ No newline at end of file From 7b5dfd375ffa19a9e70fc252180755d5b3e7d28c Mon Sep 17 00:00:00 2001 From: 
Amogh Tantradi Date: Sun, 12 Jan 2025 14:09:02 -0800 Subject: [PATCH 04/23] refactored --- lotus/vector_store/chroma_vs.py | 21 ++++++++++----------- lotus/vector_store/pinecone_vs.py | 20 ++++++++++---------- lotus/vector_store/qdrant_vs.py | 21 ++++++++++----------- lotus/vector_store/weaviate_vs.py | 16 ++++++++-------- 4 files changed, 38 insertions(+), 40 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index e94544e0..298f05e9 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,17 +1,16 @@ from lotus.vector_store.vs import VS -try: - import chromadb -except ImportError: - chromadb = None - - -if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", - ) - class ChromaVS(VS): def __init__(self): + try: + import chromadb + except ImportError: + chromadb = None + + + if chromadb is None: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`", + ) pass diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 19fc0967..fdb89800 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -1,17 +1,17 @@ from lotus.vector_store.vs import VS -try: - import pinecone -except ImportError: - pinecone = None +class PineconeVS(VS): + def __init__(self): + try: + import pinecone + except ImportError: + pinecone = None -if pinecone is None: - raise ImportError( - "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", - ) + if pinecone is None: + raise ImportError( + "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", + ) -class PineconeVS(VS): - def __init__(self): pass diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 6fc82dae..5ded2d7d 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,17 +1,16 @@ from lotus.vector_store.vs import VS -try: - import qdrant_client -except ImportError: - qdrant_client = None - - -if qdrant_client is None: - raise ImportError( - "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", - ) - class QdrantVS(VS): def __init__(self): + try: + import qdrant_client + except ImportError: + qdrant_client = None + + + if qdrant_client is None: + raise ImportError( + "The qdrant library is required to use QdrantVS. 
Install it with `pip install qdrant_client`", + ) pass diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 0ee5ded2..3b747235 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,13 +1,13 @@ -from typing import Optional, Union - from lotus.vector_store.vs import VS -try: - import weaviate -except ImportError: - raise ImportError("Please install the weaviate client") - class WeaviateVS(VS): - def __init__(self, weaviate_collection_name:str, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], weaviate_collection_text_key: Optional[str] = "content"): + def __init__(self): + try: + import weaviate + except ImportError: + weaviate = None + + if weaviate is None: + raise ImportError("Please install the weaviate client") pass From 08dfabab9fc7d36dd85c8fd26d5bdf6b93e745f7 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Sun, 12 Jan 2025 22:29:30 -0800 Subject: [PATCH 05/23] added tests for clustering and filtering --- tests/test_cluster.py | 104 ++++++++++++++++++++++++++++++++++++++++++ tests/test_filter.py | 102 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 tests/test_cluster.py create mode 100644 tests/test_filter.py diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 00000000..5266feae --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,104 @@ +import pandas as pd +import pytest + +from tests.base_test import BaseTest + + +@pytest.fixture +def sample_df(): + return pd.DataFrame({ + "Course Name": [ + "Probability and Random Processes", + "Statistics and Data Analysis", + "Cooking Basics", + "Advanced Culinary Arts", + "Digital Circuit Design", + "Computer Architecture" + ] + }) + + +class TestClusterBy(BaseTest): + def test_basic_clustering(self, sample_df): + """Test basic clustering functionality with 2 clusters""" + result = sample_df.sem_cluster_by("Course Name", 2) + assert "cluster_id" in result.columns + assert len(result["cluster_id"].unique()) == 2 + assert len(result) == len(sample_df) + + + # Get the two clusters + cluster_0_courses = set(result[result["cluster_id"] == 0]["Course Name"]) + cluster_1_courses = set(result[result["cluster_id"] == 1]["Course Name"]) + + # Define the expected course groupings + tech_courses = { + "Probability and Random Processes", + "Statistics and Data Analysis", + "Digital Circuit Design", + "Computer Architecture" + } + culinary_courses = { + "Cooking Basics", + "Advanced Culinary Arts" + } + + # Check that one cluster contains tech courses and the other contains culinary courses + assert (cluster_0_courses == tech_courses and cluster_1_courses == culinary_courses) or \ + (cluster_1_courses == tech_courses and cluster_0_courses == culinary_courses), \ + "Clusters don't match expected course groupings" + + def test_clustering_with_more_clusters(self, sample_df): + """Test clustering with more clusters than necessary""" + result = sample_df.sem_cluster_by("Course Name", 3) + assert len(result["cluster_id"].unique()) == 3 + assert len(result) == len(sample_df) + + def test_clustering_with_single_cluster(self, sample_df): + """Test clustering with single cluster""" + result = sample_df.sem_cluster_by("Course Name", 1) + assert len(result["cluster_id"].unique()) == 1 + assert result["cluster_id"].iloc[0] == 0 + + def test_clustering_with_invalid_column(self, sample_df): + """Test clustering with non-existent column""" + with pytest.raises(ValueError, match="Column .* not found in 
DataFrame"): + sample_df.sem_cluster_by("NonExistentColumn", 2) + + def test_clustering_with_empty_dataframe(self): + """Test clustering on empty dataframe""" + empty_df = pd.DataFrame(columns=["Course Name"]) + result = empty_df.sem_cluster_by("Course Name", 2) + assert len(result) == 0 + assert "cluster_id" in result.columns + + def test_clustering_similar_items(self, sample_df): + """Test that similar items are clustered together""" + result = sample_df.sem_cluster_by("Course Name", 3) + + # Get cluster IDs for similar courses + stats_cluster = result[result["Course Name"].str.contains("Statistics")]["cluster_id"].iloc[0] + prob_cluster = result[result["Course Name"].str.contains("Probability")]["cluster_id"].iloc[0] + + # Similar courses should be in the same cluster + assert stats_cluster == prob_cluster + + cooking_cluster = result[result["Course Name"].str.contains("Cooking")]["cluster_id"].iloc[0] + culinary_cluster = result[result["Course Name"].str.contains("Culinary")]["cluster_id"].iloc[0] + + assert cooking_cluster == culinary_cluster + + def test_clustering_with_verbose(self, sample_df): + """Test clustering with verbose output""" + result = sample_df.sem_cluster_by("Course Name", 2, verbose=True) + assert "cluster_id" in result.columns + assert len(result["cluster_id"].unique()) == 2 + + def test_clustering_with_iterations(self, sample_df): + """Test clustering with different iteration counts""" + result1 = sample_df.sem_cluster_by("Course Name", 2, niter=5) + result2 = sample_df.sem_cluster_by("Course Name", 2, niter=20) + + # Both should produce valid clusterings + assert len(result1["cluster_id"].unique()) == 2 + assert len(result2["cluster_id"].unique()) == 2 diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 00000000..c3023824 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,102 @@ +import pandas as pd +import pytest + +from lotus.types import CascadeArgs +from tests.base_test import BaseTest + + +@pytest.fixture +def sample_df(): + return pd.DataFrame({ + "Name": ["Alice", "Bob", "Charlie"], + "Age": [25, 30, 17], + "City": ["New York", "London", "Paris"] + }) + + +class TestFilteredSearch(BaseTest): + def test_basic_filter(self, sample_df): + """Test basic filtering functionality""" + result = sample_df.sem_filter("Age greater than 20") + assert len(result) == 2 + assert all(age > 20 for age in result["Age"]) + + def test_filter_with_examples(self, sample_df): + """Test filtering with example data""" + examples = pd.DataFrame({ + "Name": ["David", "Eve"], + "Age": [40, 15], + "City": ["Berlin", "Tokyo"], + "Answer": [True, False] + }) + result = sample_df.sem_filter( + "Age greater than 20", + examples=examples + ) + assert len(result) == 2 + assert all(age > 20 for age in result["Age"]) + + def test_filter_with_explanations(self, sample_df): + """Test filtering with explanations returned""" + result = sample_df.sem_filter( + "Age greater than 20", + return_explanations=True + ) + assert "explanation_filter" in result.columns + assert len(result["explanation_filter"]) == len(result) + + def test_filter_with_raw_outputs(self, sample_df): + """Test filtering with raw outputs returned""" + result = sample_df.sem_filter( + "Age greater than 20", + return_raw_outputs=True + ) + assert "raw_output_filter" in result.columns + assert len(result["raw_output_filter"]) == len(result) + + def test_filter_with_cot_strategy(self, sample_df): + """Test filtering with chain-of-thought reasoning""" + examples = pd.DataFrame({ + "Name": ["David"], + 
"Age": [40], + "City": ["Berlin"], + "Answer": [True], + "Reasoning": ["The age is 40, which is greater than 20"] + }) + result = sample_df.sem_filter( + "Age greater than 20", + examples=examples, + strategy="cot", + return_explanations=True + ) + assert "explanation_filter" in result.columns + assert len(result) == 2 + + def test_filter_with_invalid_column(self, sample_df): + """Test filtering with non-existent column""" + with pytest.raises(ValueError, match="Column .* not found in DataFrame"): + sample_df.sem_filter("InvalidColumn greater than 20") + + def test_filter_with_cascade(self, sample_df): + """Test filtering with cascade arguments""" + cascade_args = CascadeArgs( + recall_target=0.9, + precision_target=0.9, + sampling_percentage=0.1, + failure_probability=0.2 + ) + result, stats = sample_df.sem_filter( + "Age greater than 20", + cascade_args=cascade_args, + return_stats=True + ) + assert isinstance(stats, dict) + assert "pos_cascade_threshold" in stats + assert "neg_cascade_threshold" in stats + assert len(result) == 2 + + def test_empty_dataframe(self): + """Test filtering on empty dataframe""" + empty_df = pd.DataFrame(columns=["Name", "Age", "City"]) + result = empty_df.sem_filter("Age greater than 20") + assert len(result) == 0 \ No newline at end of file From f3a82c1f80390c1b5cff6e1b03552834c2fc2cc5 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Mon, 13 Jan 2025 12:05:44 -0800 Subject: [PATCH 06/23] made edits to test_filter --- tests/test_filter.py | 170 ++++++++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 81 deletions(-) diff --git a/tests/test_filter.py b/tests/test_filter.py index c3023824..1611340a 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,102 +1,110 @@ import pandas as pd import pytest -from lotus.types import CascadeArgs from tests.base_test import BaseTest @pytest.fixture def sample_df(): return pd.DataFrame({ - "Name": ["Alice", "Bob", "Charlie"], - "Age": [25, 30, 17], - "City": ["New York", "London", "Paris"] + "Course Name": [ + "Introduction to Programming", + "Advanced Programming", + "Cooking Basics", + "Advanced Culinary Arts", + "Data Structures", + "Algorithms", + "French Cuisine", + "Italian Cooking" + ], + "Department": [ + "CS", "CS", "Culinary", "Culinary", + "CS", "CS", "Culinary", "Culinary" + ], + "Level": [ + 100, 200, 100, 200, + 300, 300, 200, 200 + ] }) -class TestFilteredSearch(BaseTest): - def test_basic_filter(self, sample_df): - """Test basic filtering functionality""" - result = sample_df.sem_filter("Age greater than 20") +class TestSearch(BaseTest): + def test_basic_search(self, sample_df): + """Test basic semantic search functionality""" + df = sample_df.sem_index("Course Name", "course_index") + result = df.sem_search("Course Name", "programming courses", K=2) assert len(result) == 2 - assert all(age > 20 for age in result["Age"]) + assert "Introduction to Programming" in result["Course Name"].values + assert "Advanced Programming" in result["Course Name"].values - def test_filter_with_examples(self, sample_df): - """Test filtering with example data""" - examples = pd.DataFrame({ - "Name": ["David", "Eve"], - "Age": [40, 15], - "City": ["Berlin", "Tokyo"], - "Answer": [True, False] - }) - result = sample_df.sem_filter( - "Age greater than 20", - examples=examples - ) + def test_filtered_search_relational(self, sample_df): + """Test semantic search with relational filter""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply relational 
filter and search + filtered_df = df[df["Department"] == "CS"] + result = filtered_df.sem_search("Course Name", "advanced courses", K=2) + assert len(result) == 2 - assert all(age > 20 for age in result["Age"]) - - def test_filter_with_explanations(self, sample_df): - """Test filtering with explanations returned""" - result = sample_df.sem_filter( - "Age greater than 20", - return_explanations=True - ) - assert "explanation_filter" in result.columns - assert len(result["explanation_filter"]) == len(result) + # Should only return CS courses + assert all(dept == "CS" for dept in result["Department"]) + assert "Advanced Programming" in result["Course Name"].values - def test_filter_with_raw_outputs(self, sample_df): - """Test filtering with raw outputs returned""" - result = sample_df.sem_filter( - "Age greater than 20", - return_raw_outputs=True - ) - assert "raw_output_filter" in result.columns - assert len(result["raw_output_filter"]) == len(result) + def test_filtered_search_semantic(self, sample_df): + """Test semantic search after semantic filter""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply semantic filter and search + filtered_df = df.sem_filter("{Course Name} is related to cooking") + result = filtered_df.sem_search("Course Name", "advanced level courses", K=2) + + assert len(result) == 2 + # Should only return cooking-related courses + assert all(dept == "Culinary" for dept in result["Department"]) + assert "Advanced Culinary Arts" in result["Course Name"].values - def test_filter_with_cot_strategy(self, sample_df): - """Test filtering with chain-of-thought reasoning""" - examples = pd.DataFrame({ - "Name": ["David"], - "Age": [40], - "City": ["Berlin"], - "Answer": [True], - "Reasoning": ["The age is 40, which is greater than 20"] - }) - result = sample_df.sem_filter( - "Age greater than 20", - examples=examples, - strategy="cot", - return_explanations=True - ) - assert "explanation_filter" in result.columns + def test_filtered_search_combined(self, sample_df): + """Test semantic search with both relational and semantic filters""" + # Index the dataframe + df = sample_df.sem_index("Course Name", "course_index") + + # Apply both filters and search + filtered_df = df[df["Level"] >= 200] # relational filter + filtered_df = filtered_df.sem_filter("{Course Name} is related to computer science") # semantic filter + result = filtered_df.sem_search("Course Name", "data structures and algorithms", K=2) + assert len(result) == 2 + # Should only return advanced CS courses + assert all(dept == "CS" for dept in result["Department"]) + assert all(level >= 200 for level in result["Level"]) + assert "Data Structures" in result["Course Name"].values + assert "Algorithms" in result["Course Name"].values - def test_filter_with_invalid_column(self, sample_df): - """Test filtering with non-existent column""" - with pytest.raises(ValueError, match="Column .* not found in DataFrame"): - sample_df.sem_filter("InvalidColumn greater than 20") + def test_filtered_search_empty_result(self, sample_df): + """Test semantic search when filter returns empty result""" + df = sample_df.sem_index("Course Name", "course_index") + + # Apply filter that should return no results + filtered_df = df[df["Level"] > 1000] + result = filtered_df.sem_search("Course Name", "any course", K=2) + + assert len(result) == 0 - def test_filter_with_cascade(self, sample_df): - """Test filtering with cascade arguments""" - cascade_args = CascadeArgs( - recall_target=0.9, - 
precision_target=0.9, - sampling_percentage=0.1, - failure_probability=0.2 - ) - result, stats = sample_df.sem_filter( - "Age greater than 20", - cascade_args=cascade_args, - return_stats=True + def test_filtered_search_with_scores(self, sample_df): + """Test filtered semantic search with similarity scores""" + df = sample_df.sem_index("Course Name", "course_index") + + filtered_df = df[df["Department"] == "CS"] + result = filtered_df.sem_search( + "Course Name", + "programming courses", + K=2, + return_scores=True ) - assert isinstance(stats, dict) - assert "pos_cascade_threshold" in stats - assert "neg_cascade_threshold" in stats - assert len(result) == 2 - - def test_empty_dataframe(self): - """Test filtering on empty dataframe""" - empty_df = pd.DataFrame(columns=["Name", "Age", "City"]) - result = empty_df.sem_filter("Age greater than 20") - assert len(result) == 0 \ No newline at end of file + + assert "vec_scores_sim_score" in result.columns + assert len(result["vec_scores_sim_score"]) == 2 + # Scores should be between 0 and 1 + assert all(0 <= score <= 1 for score in result["vec_scores_sim_score"]) \ No newline at end of file From fc62846cbda59281499824af33de0bae53661c00 Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Mon, 13 Jan 2025 22:07:32 -0800 Subject: [PATCH 07/23] added implementations for weaviate and pinecone vs --- lotus/vector_store/pinecone_vs.py | 154 +++++++++++++++++++++++++-- lotus/vector_store/vs.py | 30 +++++- lotus/vector_store/weaviate_vs.py | 171 ++++++++++++++++++++++++++++-- 3 files changed, 335 insertions(+), 20 deletions(-) diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index fdb89800..19e7f810 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -1,17 +1,153 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from pinecone import Pinecone +except ImportError as err: + raise ImportError( + "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`", + ) from err class PineconeVS(VS): - def __init__(self): - try: - import pinecone - except ImportError: - pinecone = None + def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize Pinecone client with API key and environment""" + super().__init__(embedding_model) + self.pinecone = Pinecone(api_key=api_key) + self.index = None + self.max_batch_size = max_batch_size - if pinecone is None: - raise ImportError( - "The pinecone library is required to use PineconeVS. 
Install it with `pip install pinecone`", + def index(self, docs: pd.Series, collection_name: str): + """Create an index and add documents to it""" + self.collection_name = collection_name + + # Get sample embedding to determine vector dimension + sample_embedding = self._embed([docs.iloc[0]]) + dimension = sample_embedding.shape[1] + + # Check if index already exists + if collection_name not in self.pinecone.list_indexes(): + # Create new index with the correct dimension + self.pinecone.create_index( + name=collection_name, + dimension=dimension, + metric="cosine" ) + + # Connect to index + self.pc_index = self.pinecone.Index(collection_name) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Create embeddings using the provided embedding model + embeddings = self._batch_embed(docs_list) + + # Prepare vectors for upsert + vectors = [] + for idx, (embedding, doc) in enumerate(zip(embeddings, docs_list)): + vectors.append({ + "id": str(idx), + "values": embedding.tolist(), # Pinecone expects lists, not numpy arrays + "metadata": { + "content": doc, + "doc_id": idx + } + }) + + # Upsert in batches of 100 + batch_size = 100 + for i in tqdm(range(0, len(vectors), batch_size), desc="Uploading to Pinecone"): + batch = vectors[i:i + batch_size] + self.pc_index.upsert(vectors=batch) + + def load_index(self, collection_name: str): + """Connect to an existing Pinecone index""" + if collection_name not in self.pinecone.list_indexes(): + raise ValueError(f"Index {collection_name} not found") + + self.collection_name = collection_name + self.pc_index = self.pinecone.Index(collection_name) + + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using Pinecone""" + if self.pc_index is None: + raise ValueError("No index loaded. 
Call load_index first.") + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + all_distances = [] + all_indices = [] + + for query_vector in query_vectors: + # Query Pinecone + results = self.index.query( + vector=query_vector.tolist(), + top_k=K, + include_metadata=True, + **kwargs + ) + + # Extract distances and indices + distances = [] + indices = [] + + for match in results.matches: + indices.append(int(match.metadata["doc_id"])) + distances.append(match.score) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) # Use -1 for padding + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.pc_index is None or self.collection_name != collection_name: + self.load_index(collection_name) + + # Fetch vectors from Pinecone + vectors = [] + for doc_id in ids: + response = self.pc_index.fetch(ids=[str(doc_id)]) + if str(doc_id) in response.vectors: + vector = response.vectors[str(doc_id)].values + vectors.append(vector) + else: + raise ValueError(f"Document with id {doc_id} not found") - pass + return np.array(vectors, dtype=np.float64) diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py index bb4878eb..89a7bce8 100644 --- a/lotus/vector_store/vs.py +++ b/lotus/vector_store/vs.py @@ -1,26 +1,38 @@ from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any import numpy as np import pandas as pd +import tqdm from numpy.typing import NDArray from PIL import Image +from lotus.dtype_extensions import convert_to_base_data from lotus.types import RMOutput class VS(ABC): """Abstract class for vector stores.""" - def __init__(self) -> None: + def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None: + self.collection_name: str | None = None + self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model pass @abstractmethod - def index(self, docs: pd.Series, index_dir): + def index(self, docs: pd.Series, collection_name: str): + """ + Create index and store it in vector store + """ pass + @abstractmethod + def load_index(self, collection_name: str): + """Load the index from the vector store into memory ?? 
(not sure if this is needed )""" + @abstractmethod - def search(self, + def __call__(self, queries: pd.Series | str | Image.Image | list | NDArray[np.float64], K:int, **kwargs: dict[str, Any], @@ -29,4 +41,14 @@ def search(self, @abstractmethod def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]: - pass \ No newline at end of file + pass + + def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]: + """Create embeddings using the provided embedding model with batching""" + all_embeddings = [] + for i in tqdm(range(0, len(docs), self.max_batch_size), desc="Creating embeddings"): + batch = docs[i : i + self.max_batch_size] + _batch = convert_to_base_data(batch) + embeddings = self._embed(_batch) + all_embeddings.append(embeddings) + return np.vstack(all_embeddings) diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py index 3b747235..7e84a2a4 100644 --- a/lotus/vector_store/weaviate_vs.py +++ b/lotus/vector_store/weaviate_vs.py @@ -1,13 +1,170 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from uuid import uuid4 + + import weaviate + from weaviate.util import get_valid_uuid +except ImportError as err: + raise ImportError("Please install the weaviate client") from err class WeaviateVS(VS): - def __init__(self): + def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with Weaviate client and embedding model""" + super().__init__(embedding_model) + self.client = weaviate_client + self.max_batch_size = max_batch_size + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Get sample embedding to determine vector dimension + sample_embedding = self._embed([docs.iloc[0]]) + vector_dim = sample_embedding.shape[1] + + # Create collection without vectorizer config (we'll provide vectors directly) + collection = self.client.collections.create( + name=collection_name, + properties=[ + { + "name": "content", + "dataType": ["text"], + }, + { + "name": "doc_id", + "dataType": ["int"], + } + ], + vectorizer_config=None, # No vectorizer needed as we provide vectors + vector_index_config={"distance": "cosine"}, + vectorIndexConfig={ + "distance": "cosine", + "dimension": vector_dim + } + ) + + # Generate embeddings for all documents + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + embeddings = self._batch_embed(docs_list) + + # Add documents to collection with their embeddings + with collection.batch.dynamic() as batch: + for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)): + properties = { + "content": doc, + "doc_id": idx + } + batch.add_object( + properties=properties, + vector=embedding.tolist(), # Provide pre-computed vector + uuid=get_valid_uuid(str(uuid4())) + ) + + def load_index(self, collection_name: str): + """Load/set the collection name to use""" + self.collection_name = collection_name + # Verify collection exists try: - import weaviate - except ImportError: - weaviate = None + self.client.collections.get(collection_name) + except weaviate.exceptions.UnexpectedStatusCodeException: + raise ValueError(f"Collection {collection_name} not found") + + 
def __call__(self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using pre-computed query vectors""" + if self.collection_name is None: + raise ValueError("No collection loaded. Call load_index first.") + + collection = self.client.collections.get(self.collection_name) + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Generate embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + results = [] + for query_vector in query_vectors: + response = (collection.query + .near_vector({ + "vector": query_vector.tolist() + }) + .with_limit(K) + .with_additional(['distance']) + .with_fields(['doc_id']) + .do()) + results.append(response) + + # Process results into expected format + all_distances = [] + all_indices = [] + + for result in results: + objects = result.get('data', {}).get('Get', {}).get(self.collection_name, []) + + distances = [] + indices = [] + for obj in objects: + indices.append(obj['doc_id']) + # Convert cosine distance to similarity score + distance = obj.get('_additional', {}).get('distance', 0) + distances.append(1 - distance) # Convert distance to similarity + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + collection = self.client.collections.get(collection_name) + + # Query for documents with specific doc_ids + vectors = [] + for doc_id in ids: + response = (collection.query + .with_fields(['_additional {vector}']) + .with_where({ + 'path': ['doc_id'], + 'operator': 'Equal', + 'valueNumber': doc_id + }) + .do()) - if weaviate is None: - raise ImportError("Please install the weaviate client") - pass + # Extract vector from response + objects = response.get('data', {}).get('Get', {}).get(collection_name, []) + if objects: + vector = objects[0].get('_additional', {}).get('vector', []) + vectors.append(vector) + else: + raise ValueError(f"Document with id {doc_id} not found") + + return np.array(vectors, dtype=np.float64) + + \ No newline at end of file From f2937adcd0ee6f02589e922e0b426a95e8410d0a Mon Sep 17 00:00:00 2001 From: Amogh Tantradi Date: Tue, 14 Jan 2025 13:37:21 -0800 Subject: [PATCH 08/23] added extra refactoring and added implementations for qdrant and chroma_vs --- lotus/vector_store/chroma_vs.py | 147 ++++++++++++++++++++++++++-- lotus/vector_store/pinecone_vs.py | 2 +- lotus/vector_store/qdrant_vs.py | 155 ++++++++++++++++++++++++++++-- lotus/vector_store/vs.py | 3 +- lotus/vector_store/weaviate_vs.py | 2 +- 5 files changed, 289 insertions(+), 20 deletions(-) diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py index 298f05e9..d05b1e47 100644 --- a/lotus/vector_store/chroma_vs.py +++ b/lotus/vector_store/chroma_vs.py @@ -1,16 +1,147 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from 
lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + import chromadb + from chromadb.api import Collection +except ImportError as err: + raise ImportError( + "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`" + ) from err class ChromaVS(VS): - def __init__(self): + def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with ChromaDB client and embedding model""" + super().__init__(embedding_model) + self.client = client + self.collection: Collection | None = None + self.collection_name = None + self.max_batch_size = max_batch_size + + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Create collection without embedding function (we'll provide embeddings directly) + self.collection = self.client.create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} # Use cosine similarity for consistency + ) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Generate embeddings + embeddings = self._batch_embed(docs_list) + + # Prepare documents for addition + ids = [str(i) for i in range(len(docs_list))] + metadatas = [{"doc_id": i} for i in range(len(docs_list))] + + # Add documents in batches + batch_size = 100 + for i in tqdm(range(0, len(docs_list), batch_size), desc="Uploading to ChromaDB"): + end_idx = min(i + batch_size, len(docs_list)) + self.collection.add( + ids=ids[i:end_idx], + documents=docs_list[i:end_idx], + embeddings=embeddings[i:end_idx].tolist(), + metadatas=metadatas[i:end_idx] + ) + + def load_index(self, collection_name: str): + """Load an existing collection""" try: - import chromadb - except ImportError: - chromadb = None + self.collection = self.client.get_collection(collection_name) + self.collection_name = collection_name + except ValueError as e: + raise ValueError(f"Collection {collection_name} not found") from e + + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using ChromaDB""" + if self.collection is None: + raise ValueError("No collection loaded. Call load_index first.") + + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + # Perform searches + all_distances = [] + all_indices = [] - if chromadb is None: - raise ImportError( - "The chromadb library is required to use ChromaVS. 
Install it with `pip install chromadb`", + for query_vector in query_vectors: + results = self.collection.query( + query_embeddings=[query_vector.tolist()], + n_results=K, + include=['metadatas', 'distances'] ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + if results['metadatas']: + for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): + indices.append(metadata['doc_id']) + # ChromaDB returns squared L2 distances, convert to cosine similarity + # similarity = 1 - (distance / 2) # Convert L2 distance to cosine similarity + distances.append(1 - (distance / 2)) + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection is None or self.collection_name != collection_name: + self.load_index(collection_name) + + # Convert integer ids to strings for ChromaDB + str_ids = [str(id) for id in ids] + + # Get embeddings from ChromaDB + results = self.collection.get( + ids=str_ids, + include=['embeddings'] + ) + + if not results['embeddings']: + raise ValueError("No vectors found for the given ids") + + return np.array(results['embeddings'], dtype=np.float64) + + \ No newline at end of file diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py index 19e7f810..da8cc343 100644 --- a/lotus/vector_store/pinecone_vs.py +++ b/lotus/vector_store/pinecone_vs.py @@ -17,7 +17,7 @@ ) from err class PineconeVS(VS): - def __init__(self, api_key: str, environment: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): """Initialize Pinecone client with API key and environment""" super().__init__(embedding_model) self.pinecone = Pinecone(api_key=api_key) diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py index 5ded2d7d..c672727a 100644 --- a/lotus/vector_store/qdrant_vs.py +++ b/lotus/vector_store/qdrant_vs.py @@ -1,16 +1,153 @@ +from typing import Any, Callable, Union + +import numpy as np +import pandas as pd +from numpy.typing import NDArray +from PIL import Image +from tqdm import tqdm + +from lotus.types import RMOutput from lotus.vector_store.vs import VS +try: + from qdrant_client import QdrantClient + from qdrant_client.models import Distance, PointStruct, VectorParams +except ImportError as err: + raise ImportError("Please install the qdrant client") from err class QdrantVS(VS): - def __init__(self): - try: - import qdrant_client - except ImportError: - qdrant_client = None + def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64): + """Initialize with Qdrant client and embedding model""" + super().__init__(embedding_model) # Fixed the super() call syntax + self.client = client + self.max_batch_size = max_batch_size + + def index(self, docs: pd.Series, collection_name: str): + """Create a collection and add documents with their embeddings""" + self.collection_name = collection_name + + # Get sample embedding to determine vector 
dimension + sample_embedding = self._embed([docs.iloc[0]]) + dimension = sample_embedding.shape[1] + + # Create collection if it doesn't exist + if not self.client.collection_exists(collection_name): + self.client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=dimension, distance=Distance.COSINE) + ) + + # Convert docs to list if it's a pandas Series + docs_list = docs.tolist() if isinstance(docs, pd.Series) else docs + + # Generate embeddings + embeddings = self._batch_embed(docs_list) + + # Prepare points for upload + points = [] + for idx, (doc, embedding) in enumerate(zip(docs_list, embeddings)): + points.append( + PointStruct( + id=idx, + vector=embedding.tolist(), + payload={ + "content": doc, + "doc_id": idx + } + ) + ) + + # Upload in batches + batch_size = 100 + for i in tqdm(range(0, len(points), batch_size), desc="Uploading to Qdrant"): + batch = points[i:i + batch_size] + self.client.upsert( + collection_name=collection_name, + points=batch + ) + + def load_index(self, collection_name: str): + """Set the collection name to use""" + if not self.client.collection_exists(collection_name): + raise ValueError(f"Collection {collection_name} not found") + self.collection_name = collection_name + def __call__( + self, + queries: Union[pd.Series, str, Image.Image, list, NDArray[np.float64]], + K: int, + **kwargs: dict[str, Any] + ) -> RMOutput: + """Perform vector search using Qdrant""" + if self.collection_name is None: + raise ValueError("No collection loaded. Call load_index first.") - if qdrant_client is None: - raise ImportError( - "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`", + # Convert single query to list + if isinstance(queries, (str, Image.Image)): + queries = [queries] + + # Handle numpy array queries (pre-computed vectors) + if isinstance(queries, np.ndarray): + query_vectors = queries + else: + # Convert queries to list if needed + if isinstance(queries, pd.Series): + queries = queries.tolist() + # Create embeddings for text queries + query_vectors = self._batch_embed(queries) + + # Perform searches + all_distances = [] + all_indices = [] + + for query_vector in query_vectors: + results = self.client.search( + collection_name=self.collection_name, + query_vector=query_vector.tolist(), + limit=K, + with_payload=True ) - pass + + # Extract distances and indices + distances = [] + indices = [] + + for result in results: + indices.append(result.payload["doc_id"]) + distances.append(result.score) # Qdrant returns cosine similarity directly + + # Pad results if fewer than K matches + while len(indices) < K: + indices.append(-1) + distances.append(0.0) + + all_distances.append(distances) + all_indices.append(indices) + + return RMOutput( + distances=np.array(all_distances, dtype=np.float32), + indices=np.array(all_indices, dtype=np.int64) + ) + + def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]: + """Retrieve vectors for specific document IDs""" + if self.collection_name != collection_name: + self.load_index(collection_name) + + # Fetch points from Qdrant + points = self.client.retrieve( + collection_name=collection_name, + ids=ids, + with_vectors=True, + with_payload=False + ) + + # Extract and return vectors + vectors = [] + for point in points: + if point.vector is not None: + vectors.append(point.vector) + else: + raise ValueError(f"Vector not found for id {point.id}") + + return np.array(vectors, dtype=np.float64) \ No newline at end of file 
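
A minimal usage sketch for the QdrantVS backend defined above — not part of the patch itself. It assumes a local in-memory Qdrant instance and sentence-transformers (already pinned in requirements.txt) as the embedding callable; the model name and collection name are illustrative. One caveat: the shared `_batch_embed` helper in vs.py needs `from tqdm import tqdm` (the patch imports the module itself, so `tqdm(range(...))` would raise a TypeError) for this to run as written.

import pandas as pd
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

from lotus.vector_store import QdrantVS

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative embedding model choice
vs = QdrantVS(QdrantClient(":memory:"), lambda docs: model.encode(list(docs)))

docs = pd.Series(["Cooking Basics", "Computer Architecture", "Algorithms"])
vs.index(docs, "courses")             # embeds in batches and upserts PointStructs
vs.load_index("courses")
output = vs("hardware courses", K=2)  # RMOutput with .distances and .indices, padded to K
print(output.indices[0], output.distances[0])
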
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index 89a7bce8..0f37a1a7 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -18,7 +18,8 @@ class VS(ABC):
     def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None:
         self.collection_name: str | None = None
         self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model
-        pass
+        self.max_batch_size:int = 64
+
 
     @abstractmethod
     def index(self, docs: pd.Series, collection_name: str):
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 364bca31..e1e957b5 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -17,7 +17,7 @@
     raise ImportError("Please install the weaviate client") from err
 
 class WeaviateVS(VS):
-    def __init__(self, weaviate_client: Union[weaviate.WeaviateClient, weaviate.Client], embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
         """Initialize with Weaviate client and embedding model"""
         super().__init__(embedding_model)
         self.client = weaviate_client

From a4c741817b74c3e243f2f0bce79f91445cbc25df Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 15:00:51 -0800
Subject: [PATCH 09/23] fixed some type errors

---
 lotus/vector_store/weaviate_vs.py | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index e1e957b5..0de158af 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -12,6 +12,7 @@
     from uuid import uuid4
 
     import weaviate
+    from weaviate.classes.config import Configure, DataType, Property
     from weaviate.util import get_valid_uuid
 except ImportError as err:
     raise ImportError("Please install the weaviate client") from err
@@ -27,29 +28,21 @@ def index(self, docs: pd.Series, collection_name: str):
         """Create a collection and add documents with their embeddings"""
         self.collection_name = collection_name
 
-        # Get sample embedding to determine vector dimension
-        sample_embedding = self._embed([docs.iloc[0]])
-        vector_dim = sample_embedding.shape[1]
-
         # Create collection without vectorizer config (we'll provide vectors directly)
         collection = self.client.collections.create(
             name=collection_name,
             properties=[
-                {
-                    "name": "content",
-                    "dataType": ["text"],
-                },
-                {
-                    "name": "doc_id",
-                    "dataType": ["int"],
-                }
+                Property(
+                    name='content',
+                    data_type=DataType.TEXT
+                ),
+                Property(
+                    name='doc_id',
+                    data_type=DataType.INT,
+                )
             ],
             vectorizer_config=None,  # No vectorizer needed as we provide vectors
-            vector_index_config={"distance": "cosine"},
-            vectorIndexConfig={
-                "distance": "cosine",
-                "dimension": vector_dim
-            }
+            vector_index_config=Configure.VectorIndex.dynamic()
         )
 
         # Generate embeddings for all documents

From 1357fb339fd5abb7a4492242152bac2bee8a32cd Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 18:24:45 -0800
Subject: [PATCH 10/23] made further corrections

---
 lotus/vector_store/weaviate_vs.py | 48 ++++++++++++------------------
 1 file changed, 19 insertions(+), 29 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 0de158af..d05fe6eb 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -13,6 +13,7 @@
     import weaviate
     from weaviate.classes.config import Configure, DataType, Property
+    from weaviate.classes.query import MetadataQuery
     from weaviate.util import get_valid_uuid
 except ImportError as err:
     raise ImportError("Please install the weaviate client") from err
@@ -58,7 +59,7 @@ def index(self, docs: pd.Series, collection_name: str):
                 }
                 batch.add_object(
                     properties=properties,
-                    vector=embedding.tolist(), # Provide pre-computed vector
+                    vector=embedding.tolist(),  # Provide pre-computed vector
                     uuid=get_valid_uuid(str(uuid4()))
                 )
@@ -97,13 +98,11 @@ def __call__(self,
         results = []
         for query_vector in query_vectors:
             response = (collection.query
-                        .near_vector({
-                            "vector": query_vector.tolist()
-                        })
-                        .with_limit(K)
-                        .with_additional(['distance'])
-                        .with_fields(['doc_id'])
-                        .do())
+                        .near_vector(
+                            near_vector=query_vector.tolist(),
+                            limit=K,
+                            return_metadata=MetadataQuery(distance=True)
+                        ))
             results.append(response)
 
         # Process results into expected format
@@ -111,14 +110,14 @@ def __call__(self,
         all_indices = []
 
         for result in results:
-            objects = result.get('data', {}).get('Get', {}).get(self.collection_name, [])
+            objects = result.objects
             distances = []
             indices = []
 
             for obj in objects:
-                indices.append(obj['doc_id'])
+                indices.append(obj.properties.get('content'))
                 # Convert cosine distance to similarity score
-                distance = obj.get('_additional', {}).get('distance', 0)
+                distance = obj.metadata.distance
                 distances.append(1 - distance)  # Convert distance to similarity
 
             # Pad results if fewer than K matches
@@ -130,8 +129,8 @@ def __call__(self,
             all_indices.append(indices)
 
         return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32),
-            indices=np.array(all_indices, dtype=np.int64)
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
@@ -140,23 +139,14 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
         collection = self.client.collections.get(collection_name)
 
         # Query for documents with specific doc_ids
         vectors = []
-        for doc_id in ids:
-            response = (collection.query
-                        .with_fields(['_additional {vector}'])
-                        .with_where({
-                            'path': ['doc_id'],
-                            'operator': 'Equal',
-                            'valueNumber': doc_id
-                        })
-                        .do())
-
-            # Extract vector from response
-            objects = response.get('data', {}).get('Get', {}).get(collection_name, [])
-            if objects:
-                vector = objects[0].get('_additional', {}).get('vector', [])
-                vectors.append(vector)
+
+        response = collection.query.fetch_objects_by_ids(ids=ids)
+        for id in ids:
+            response = collection.query.fetch_object_by_id(uuid=id)
+            if response:
+                vectors.append(response.vector)
             else:
-                raise ValueError(f"Document with id {doc_id} not found")
+                raise ValueError(f'{id} does not exist in {collection_name}')
 
         return np.array(vectors, dtype=np.float64)

From c76b658519512dc4ab52e997efa257eade523c15 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 18:33:53 -0800
Subject: [PATCH 11/23] edit uuid type

---
 lotus/vector_store/weaviate_vs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index d05fe6eb..eab46d0c 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -133,7 +133,7 @@ def __call__(self,
             indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
-    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name: str, ids: list[uuid4]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
         collection = self.client.collections.get(collection_name)

From 9f257f77eab80be60ea30a876414f6294287b5ed Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 18:57:31 -0800
Subject: [PATCH 12/23] changed uuid type

---
 lotus/vector_store/weaviate_vs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index eab46d0c..82fff0b4 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -133,7 +133,7 @@ def __call__(self,
             indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
-    def get_vectors_from_index(self, collection_name: str, ids: list[uuid4]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
         collection = self.client.collections.get(collection_name)

From 99cb535ad92f0bcab42f818b8366e973dd4c8ed7 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 22:15:43 -0800
Subject: [PATCH 13/23] made type changes to weaviate file

---
 lotus/vector_store/vs.py          |  2 +-
 lotus/vector_store/weaviate_vs.py | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index 0f37a1a7..c370df61 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -41,7 +41,7 @@ def __call__(self,
         pass
 
     @abstractmethod
-    def get_vectors_from_index(self, collection_name:str, ids: list[int]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]:
         pass
 
     def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 82fff0b4..26585cc8 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Union
+from typing import Any, Callable, List, Union
 
 import numpy as np
 import pandas as pd
@@ -103,6 +103,7 @@ def __call__(self,
                             limit=K,
                             return_metadata=MetadataQuery(distance=True)
                         ))
+            response.objects[0].metadata.distance
            results.append(response)
 
         # Process results into expected format
@@ -112,12 +113,12 @@ def __call__(self,
 
         for result in results:
             objects = result.objects
-            distances = []
+            distances:List[float] = []
             indices = []
             for obj in objects:
                 indices.append(obj.properties.get('content'))
                 # Convert cosine distance to similarity score
-                distance = obj.metadata.distance
+                distance:float = obj.metadata.distance
                 distances.append(1 - distance)  # Convert distance to similarity
 
             # Pad results if fewer than K matches
@@ -133,21 +134,19 @@ def __call__(self,
             indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
-    def get_vectors_from_index(self, collection_name: str, ids: list[str]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
         collection = self.client.collections.get(collection_name)
 
         # Query for documents with specific doc_ids
         vectors = []
-
-        response = collection.query.fetch_objects_by_ids(ids=ids)
         for id in ids:
             response = collection.query.fetch_object_by_id(uuid=id)
             if response:
                 vectors.append(response.vector)
             else:
                 raise ValueError(f'{id} does not exist in {collection_name}')
-
         return np.array(vectors, dtype=np.float64)

From 3c8a742f5123888f3c0197acee18228893d22fa9 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 22:21:27 -0800
Subject: [PATCH 14/23] made another change

---
 lotus/vector_store/vs.py          | 2 +-
 lotus/vector_store/weaviate_vs.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index c370df61..7d3e2c00 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -41,7 +41,7 @@ def __call__(self,
         pass
 
     @abstractmethod
-    def get_vectors_from_index(self, collection_name:str, ids: list[any]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name:str, ids: list[Any]) -> NDArray[np.float64]:
         pass
 
     def _batch_embed(self, docs: pd.Series | list) -> NDArray[np.float64]:
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 26585cc8..702fb510 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -134,7 +134,7 @@ def __call__(self,
         )
 
-    def get_vectors_from_index(self, collection_name: str, ids: list[any]) -> NDArray[np.float64]:
+    def get_vectors_from_index(self, collection_name: str, ids: list[Any]) -> NDArray[np.float64]:
         """Retrieve vectors for specific document IDs"""
         collection = self.client.collections.get(collection_name)

From ccd9e489a5ff23d6a85e3420b5a3972d8f18f544 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Tue, 14 Jan 2025 22:30:20 -0800
Subject: [PATCH 15/23] typecheck passes for weaviate?

---
 lotus/vector_store/weaviate_vs.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 702fb510..44eb7463 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -118,9 +118,8 @@ def __call__(self,
             for obj in objects:
                 indices.append(obj.properties.get('content'))
                 # Convert cosine distance to similarity score
-                distance:float = obj.metadata.distance
-                distances.append(1 - distance)  # Convert distance to similarity
-
+                distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0
+                distances.append(1 - distance)  # Convert distance to similarity
             # Pad results if fewer than K matches

From 89bf9743ec5358241ac96595d05a3ffbba075739 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Wed, 15 Jan 2025 17:09:59 -0800
Subject: [PATCH 16/23] type changes for weaviate and qdrant files

---
 lotus/vector_store/qdrant_vs.py   | 6 +++---
 lotus/vector_store/weaviate_vs.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index c672727a..28a7c7dd 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -113,7 +113,7 @@ def __call__(
             indices = []
 
             for result in results:
-                indices.append(result.payload["doc_id"])
+                indices.append(result.id)
                 distances.append(result.score)  # Qdrant returns cosine similarity directly
 
             # Pad results if fewer than K matches
@@ -125,8 +125,8 @@ def __call__(
             all_indices.append(indices)
 
         return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32),
-            indices=np.array(all_indices, dtype=np.int64)
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
     def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 44eb7463..9f18ce0c 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -103,7 +103,7 @@ def __call__(self,
                             limit=K,
                             return_metadata=MetadataQuery(distance=True)
                         ))
-            response.objects[0].metadata.distance
+            response.objects[0].uuid
             results.append(response)
 
         # Process results into expected format
@@ -116,7 +116,7 @@ def __call__(self,
             distances:List[float] = []
             indices = []
             for obj in objects:
-                indices.append(obj.properties.get('content'))
+                indices.append(obj.uuid)
                 # Convert cosine distance to similarity score
                 distance = obj.metadata.distance if obj.metadata and obj.metadata.distance is not None else 1.0
                 distances.append(1 - distance)  # Convert distance to similarity

From a76adb76c29dfc3c21d409077e9b08c411e3a746 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Wed, 15 Jan 2025 17:18:23 -0800
Subject: [PATCH 17/23] made changes to weaviate file

---
 lotus/vector_store/weaviate_vs.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 9f18ce0c..e12f34e2 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -9,6 +9,7 @@
 from lotus.vector_store.vs import VS
 
 try:
+    import uuid
     from uuid import uuid4
 
     import weaviate
@@ -122,7 +123,7 @@ def __call__(self,
                 distances.append(1 - distance)  # Convert distance to similarity
             # Pad results if fewer than K matches
             while len(indices) < K:
-                indices.append(-1)
+                indices.append(uuid.UUID(0))
                 distances.append(0.0)
 
             all_distances.append(distances)

From c3e0f0c9e4f54657a7647dff523f002e11e46918 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Wed, 15 Jan 2025 17:29:03 -0800
Subject: [PATCH 18/23] made changes to weaviate file

---
 lotus/vector_store/weaviate_vs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index e12f34e2..8d9937fa 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -123,7 +123,7 @@ def __call__(self,
             distances.append(1 - distance)  # Convert distance to similarity
             # Pad results if fewer than K matches
             while len(indices) < K:
-                indices.append(uuid.UUID(0))
+                indices.append(uuid.UUID(int=0))
                 distances.append(0.0)
 
             all_distances.append(distances)

From 1782281d957bc7f59b09dd05f55c4f522cc40d62 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Wed, 15 Jan 2025 22:14:07 -0800
Subject: [PATCH 19/23] fixed pinecone type errors

---
 lotus/vector_store/pinecone_vs.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index da8cc343..34867018 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -21,7 +21,7 @@ def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
         """Initialize Pinecone client with API key and environment"""
         super().__init__(embedding_model)
         self.pinecone = Pinecone(api_key=api_key)
-        self.index = None
+        self.pc_index = None
         self.max_batch_size = max_batch_size
 
 
@@ -107,7 +107,7 @@ def __call__(
 
         for query_vector in query_vectors:
             # Query Pinecone
-            results = self.index.query(
+            results = self.pc_index.query(
                 vector=query_vector.tolist(),
                 top_k=K,
                 include_metadata=True,
@@ -131,8 +131,8 @@ def __call__(
             all_indices.append(indices)
 
         return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32),
-            indices=np.array(all_indices, dtype=np.int64)
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
     def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:

From 0621b9baffaa19e6d5142bcaa53b707a79d1266b Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Wed, 15 Jan 2025 22:46:15 -0800
Subject: [PATCH 20/23] fixed pinecone type errors

---
 lotus/vector_store/pinecone_vs.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index 34867018..892faaa2 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -10,7 +10,7 @@
 from lotus.vector_store.vs import VS
 
 try:
-    from pinecone import Pinecone
+    from pinecone import Index, Pinecone
 except ImportError as err:
     raise ImportError(
         "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
@@ -21,7 +21,7 @@ def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
         """Initialize Pinecone client with API key and environment"""
         super().__init__(embedding_model)
         self.pinecone = Pinecone(api_key=api_key)
-        self.pc_index = None
+        self.pc_index:Index | None = None
         self.max_batch_size = max_batch_size
 
 
@@ -140,6 +140,11 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
         if self.pc_index is None or self.collection_name != collection_name:
             self.load_index(collection_name)
 
+        if self.pc_index is None:  # Add this check after load_index
+            raise ValueError("Failed to initialize Pinecone index")
+
+
+
         # Fetch vectors from Pinecone
         vectors = []
         for doc_id in ids:

From b568d1ef5868f0538c8cd3abbb3571aac50366f0 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Thu, 16 Jan 2025 09:58:33 -0800
Subject: [PATCH 21/23] type checks all pass locally

---
 lotus/vector_store/chroma_vs.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index d05b1e47..3735f798 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Union
+from typing import Any, Callable, Mapping, Union
 
 import numpy as np
 import pandas as pd
@@ -11,14 +11,16 @@
 
 try:
     import chromadb
+    from chromadb import ClientAPI
     from chromadb.api import Collection
+    from chromadb.api.types import IncludeEnum
 except ImportError as err:
     raise ImportError(
         "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"
     ) from err
 
 class ChromaVS(VS):
-    def __init__(self, client: chromadb.Client, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
         """Initialize with ChromaDB client and embedding model"""
         super().__init__(embedding_model)
         self.client = client
@@ -45,7 +47,7 @@ def index(self, docs: pd.Series, collection_name: str):
 
         # Prepare documents for addition
         ids = [str(i) for i in range(len(docs_list))]
-        metadatas = [{"doc_id": i} for i in range(len(docs_list))]
+        metadatas: list[Mapping[str, Union[str, int, float, bool]]] = [{"doc_id": int(i)} for i in range(len(docs_list))]
 
         # Add documents in batches
         batch_size = 100
@@ -98,14 +100,14 @@ def __call__(
             results = self.collection.query(
                 query_embeddings=[query_vector.tolist()],
                 n_results=K,
-                include=['metadatas', 'distances']
+                include=[IncludeEnum.metadatas, IncludeEnum.distances]
             )
 
             # Extract distances and indices
             distances = []
             indices = []
 
-            if results['metadatas']:
+            if results['metadatas'] and results['distances']:
                 for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
                     indices.append(metadata['doc_id'])
                     # ChromaDB returns squared L2 distances, convert to cosine similarity
@@ -121,8 +123,8 @@ def __call__(
             all_indices.append(indices)
 
         return RMOutput(
-            distances=np.array(all_distances, dtype=np.float32),
-            indices=np.array(all_indices, dtype=np.int64)
+            distances=np.array(all_distances, dtype=np.float32).tolist(),
+            indices=np.array(all_indices, dtype=np.int64).tolist()
         )
 
@@ -130,13 +132,18 @@ def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
         if self.collection is None or self.collection_name != collection_name:
             self.load_index(collection_name)
 
+
+        if self.collection is None:  # Add this check after load_index
+            raise ValueError(f"Failed to load collection {collection_name}")
+
+
         # Convert integer ids to strings for ChromaDB
         str_ids = [str(id) for id in ids]
 
         # Get embeddings from ChromaDB
         results = self.collection.get(
             ids=str_ids,
-            include=['embeddings']
+            include=[IncludeEnum.embeddings]
         )
 
         if not results['embeddings']:

From 9b33a1f832e7e7c96af5134c8161a19ffe1dae67 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Thu, 16 Jan 2025 09:59:16 -0800
Subject: [PATCH 22/23] fixed linting errors

---
 lotus/vector_store/chroma_vs.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index 3735f798..7a7e8de9 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -10,10 +10,9 @@
 from lotus.vector_store.vs import VS
 
 try:
-    import chromadb
-    from chromadb import ClientAPI
+    from chromadb import ClientAPI
     from chromadb.api import Collection
-    from chromadb.api.types import IncludeEnum
+    from chromadb.api.types import IncludeEnum
 except ImportError as err:
     raise ImportError(
         "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`"

From 820f3beff44ee9ed123d233695b5d1a3d29a2220 Mon Sep 17 00:00:00 2001
From: Amogh Tantradi
Date: Thu, 16 Jan 2025 21:30:00 -0800
Subject: [PATCH 23/23] made refactors to allow for testing

---
 .github/tests/rm_tests.py         |  6 ++++++
 lotus/vector_store/chroma_vs.py   |  4 ++--
 lotus/vector_store/pinecone_vs.py |  4 ++--
 lotus/vector_store/qdrant_vs.py   |  4 ++--
 lotus/vector_store/vs.py          | 23 +++++++++++++++++++----
 lotus/vector_store/weaviate_vs.py |  4 ++--
 6 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/.github/tests/rm_tests.py b/.github/tests/rm_tests.py
index 2c00e116..f9bb5fae 100644
--- a/.github/tests/rm_tests.py
+++ b/.github/tests/rm_tests.py
@@ -41,9 +41,15 @@ def setup_models():
     for model_name in ENABLED_MODEL_NAMES:
         models[model_name] = MODEL_NAME_TO_CLS[model_name](model=model_name)
 
+
+
     return models
 
 
+@pytest.fixture(scope='session')
+def setup_vs():
+    pass
+
 ################################################################################
 # RM Only Tests
 ################################################################################
diff --git a/lotus/vector_store/chroma_vs.py b/lotus/vector_store/chroma_vs.py
index 7a7e8de9..c304658e 100644
--- a/lotus/vector_store/chroma_vs.py
+++ b/lotus/vector_store/chroma_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Mapping, Union
+from typing import Any, Mapping, Union
 
 import numpy as np
 import pandas as pd
@@ -19,7 +19,7 @@
 ) from err
 
 class ChromaVS(VS):
-    def __init__(self, client: ClientAPI, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, client: ClientAPI, embedding_model: str, max_batch_size: int = 64):
         """Initialize with ChromaDB client and embedding model"""
         super().__init__(embedding_model)
         self.client = client
diff --git a/lotus/vector_store/pinecone_vs.py b/lotus/vector_store/pinecone_vs.py
index 892faaa2..2aeedda8 100644
--- a/lotus/vector_store/pinecone_vs.py
+++ b/lotus/vector_store/pinecone_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Union
+from typing import Any, Union
 
 import numpy as np
 import pandas as pd
@@ -17,7 +17,7 @@
 ) from err
 
 class PineconeVS(VS):
-    def __init__(self, api_key: str, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, api_key: str, embedding_model: str, max_batch_size: int = 64):
         """Initialize Pinecone client with API key and environment"""
         super().__init__(embedding_model)
         self.pinecone = Pinecone(api_key=api_key)
diff --git a/lotus/vector_store/qdrant_vs.py b/lotus/vector_store/qdrant_vs.py
index 28a7c7dd..f6ae180b 100644
--- a/lotus/vector_store/qdrant_vs.py
+++ b/lotus/vector_store/qdrant_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Union
+from typing import Any, Union
 
 import numpy as np
 import pandas as pd
@@ -16,7 +16,7 @@
 raise ImportError("Please install the qdrant client") from err
 
 class QdrantVS(VS):
-    def __init__(self, client: QdrantClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, client: QdrantClient, embedding_model: str, max_batch_size: int = 64):
         """Initialize with Qdrant client and embedding model"""
         super().__init__(embedding_model)  # Fixed the super() call syntax
         self.client = client
diff --git a/lotus/vector_store/vs.py b/lotus/vector_store/vs.py
index 7d3e2c00..2ad4c063 100644
--- a/lotus/vector_store/vs.py
+++ b/lotus/vector_store/vs.py
@@ -1,24 +1,39 @@
 from abc import ABC, abstractmethod
-from collections.abc import Callable
 from typing import Any
 
 import numpy as np
 import pandas as pd
 import tqdm
+from litellm import embedding
 from numpy.typing import NDArray
 from PIL import Image
+from sentence_transformers import CrossEncoder, SentenceTransformer
 
 from lotus.dtype_extensions import convert_to_base_data
 from lotus.types import RMOutput
 
+MODEL_NAME_TO_CLS = {
+    "intfloat/e5-small-v2": lambda model: SentenceTransformer(model_name_or_path=model),
+    "mixedbread-ai/mxbai-rerank-xsmall-v1": lambda model: CrossEncoder(model_name=model),
+    "text-embedding-3-small": lambda model: lambda batch: embedding(model=model, input=batch),
+}
+
+
+def initialize(model_name):
+    if model_name == 'intfloat/e5-small-v2':
+        return SentenceTransformer(model_name=model_name)
+    elif model_name== 'mixedbread-ai/mxbai-rerank-xsmall-v1':
+        return CrossEncoder(model_name=model_name)
+    return lambda batch: embedding(model=model_name, input=batch)
+
 
 class VS(ABC):
     """Abstract class for vector stores."""
 
-    def __init__(self, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]]) -> None:
+    def __init__(self, embedding_model: str) -> None:
         self.collection_name: str | None = None
-        self._embed: Callable[[pd.Series | list], NDArray[np.float64]] = embedding_model
-        self.max_batch_size:int = 64
+        self._embed = initialize(embedding_model)
+        self.max_batch_size:int = 64
 
 
     @abstractmethod
diff --git a/lotus/vector_store/weaviate_vs.py b/lotus/vector_store/weaviate_vs.py
index 8d9937fa..786fe49a 100644
--- a/lotus/vector_store/weaviate_vs.py
+++ b/lotus/vector_store/weaviate_vs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Union
+from typing import Any, List, Union
 
 import numpy as np
 import pandas as pd
@@ -20,7 +20,7 @@
 raise ImportError("Please install the weaviate client") from err
 
 class WeaviateVS(VS):
-    def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: Callable[[pd.Series | list], NDArray[np.float64]], max_batch_size: int = 64):
+    def __init__(self, weaviate_client: weaviate.WeaviateClient, embedding_model: str, max_batch_size: int = 64):
         """Initialize with Weaviate client and embedding model"""
         super().__init__(embedding_model)
         self.client = weaviate_client
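
After PATCH 23, every store is constructed from a concrete client plus an embedding-model name string that vs.py's initialize() resolves. A minimal usage sketch follows, assuming a Weaviate instance on the default local ports (weaviate.connect_to_local() is the standard v4 client helper, not part of this series); the collection name and documents are made up for illustration, and since the exact __call__ signature is abbreviated in the hunks above, the query call is indicative only:

    import pandas as pd
    import weaviate

    from lotus.vector_store import WeaviateVS

    # Connect to a locally running Weaviate instance (assumed, not part of the patch).
    client = weaviate.connect_to_local()

    # "intfloat/e5-small-v2" is one of the names initialize() maps to a
    # SentenceTransformer; max_batch_size mirrors the default added in vs.py.
    vs = WeaviateVS(weaviate_client=client, embedding_model="intfloat/e5-small-v2", max_batch_size=64)

    # Index a toy document series, then retrieve the single nearest neighbor.
    docs = pd.Series(["Cats are mammals.", "The sun is a star."])
    vs.index(docs, collection_name="demo_docs")
    output = vs(["Which animal is a mammal?"], K=1)

    # RMOutput carries parallel lists of similarity scores and matched ids.
    print(output.distances, output.indices)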