-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial scaffolding for adding vector store / vector database integration (#76)
- Loading branch information
1 parent
8a207aa
commit 4d4ca82
Showing
11 changed files
with
324 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Re-export the abstract vector store base class together with every
# backend-specific subclass so callers can import them from the package root.
from lotus.vector_store.chroma_vs import ChromaVS
from lotus.vector_store.pinecone_vs import PineconeVS
from lotus.vector_store.qdrant_vs import QdrantVS
from lotus.vector_store.vs import VS
from lotus.vector_store.weaviate_vs import WeaviateVS

__all__ = [
    "VS",
    "WeaviateVS",
    "PineconeVS",
    "ChromaVS",
    "QdrantVS",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from lotus.vector_store.vs import VS | ||
|
||
|
||
class ChromaVS(VS):
    """Vector store backend placeholder for ChromaDB.

    Scaffolding only: this class defines no methods other than ``__init__``,
    which checks that the optional ``chromadb`` package can be imported.
    """

    def __init__(self):
        """Verify that ``chromadb`` is importable, then initialise the base class.

        Raises:
            ImportError: If importing ``chromadb`` fails. The underlying
                import error is attached as ``__cause__`` so a broken
                transitive dependency is not hidden behind the install hint.
        """
        # Imported lazily so that chromadb stays an optional dependency.
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`",
            ) from err

        super().__init__()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from lotus.vector_store.vs import VS | ||
|
||
|
||
class PineconeVS(VS):
    """Vector store backend placeholder for Pinecone.

    Scaffolding only: this class defines no methods other than ``__init__``,
    which checks that the optional ``pinecone`` package can be imported.
    """

    def __init__(self):
        """Verify that ``pinecone`` is importable, then initialise the base class.

        Raises:
            ImportError: If importing ``pinecone`` fails. The underlying
                import error is attached as ``__cause__`` so a broken
                transitive dependency is not hidden behind the install hint.
        """
        # Imported lazily so that pinecone stays an optional dependency.
        try:
            import pinecone  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
            ) from err

        super().__init__()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from lotus.vector_store.vs import VS | ||
|
||
|
||
class QdrantVS(VS):
    """Vector store backend placeholder for Qdrant.

    Scaffolding only: this class defines no methods other than ``__init__``,
    which checks that the optional ``qdrant_client`` package can be imported.
    """

    def __init__(self):
        """Verify that ``qdrant_client`` is importable, then initialise the base class.

        Raises:
            ImportError: If importing ``qdrant_client`` fails. The underlying
                import error is attached as ``__cause__`` so a broken
                transitive dependency is not hidden behind the install hint.
        """
        # Imported lazily so that qdrant_client stays an optional dependency.
        try:
            import qdrant_client  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`",
            ) from err

        super().__init__()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from numpy.typing import NDArray | ||
from PIL import Image | ||
|
||
from lotus.types import RMOutput | ||
|
||
|
||
class VS(ABC):
    """Abstract class for vector stores.

    Concrete backends subclass this and implement ``index``, ``search`` and
    ``get_vectors_from_index``. The base class itself holds no state.
    """

    def __init__(self) -> None:
        """Initialise the base class; there is nothing to set up here."""
        pass

    @abstractmethod
    def index(self, docs: pd.Series, index_dir):
        """Build an index over ``docs``.

        Args:
            docs: Series holding the documents to index.
            index_dir: Identifier for where the index is kept.
                NOTE(review): presumably a directory path or collection
                name string; the base class does not constrain it, confirm
                against the concrete backends.
        """

    @abstractmethod
    def search(
        self,
        queries: pd.Series | str | Image.Image | list | NDArray[np.float64],
        K: int,
        **kwargs: Any,
    ) -> RMOutput:
        """Run a search for ``queries`` against the store.

        Args:
            queries: One or more queries, given as a Series, a single
                string, a PIL image, a list, or a float64 array.
            K: NOTE(review): presumably the number of results to return
                per query; confirm against the concrete backends.
            **kwargs: Backend-specific options. Annotated ``Any`` because a
                ``**kwargs`` annotation describes each keyword *value*, not
                the mapping that collects them.

        Returns:
            The search results as an ``RMOutput``.
        """

    @abstractmethod
    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
        """Fetch stored vectors for the given ids.

        Args:
            collection_name: Name of the collection to read from.
            ids: Integer ids of the vectors to fetch.

        Returns:
            A float64 array of vectors.
            NOTE(review): shape and row ordering are not fixed by this
            interface; presumably one row per id, confirm with backends.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from lotus.vector_store.vs import VS | ||
|
||
|
||
class WeaviateVS(VS):
    """Vector store backend placeholder for Weaviate.

    Scaffolding only: this class defines no methods other than ``__init__``,
    which checks that the optional ``weaviate`` package can be imported.
    """

    def __init__(self):
        """Verify that ``weaviate`` is importable, then initialise the base class.

        Raises:
            ImportError: If importing ``weaviate`` fails. The message names
                the missing library and the install command, and the
                underlying import error is attached as ``__cause__``.
        """
        # Imported lazily so that the weaviate client stays an optional dependency.
        try:
            import weaviate  # noqa: F401
        except ImportError as err:
            # The importable module is `weaviate`, but the distribution on
            # PyPI is published as `weaviate-client`.
            raise ImportError(
                "The weaviate library is required to use WeaviateVS. Install it with `pip install weaviate-client`",
            ) from err

        super().__init__()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from tests.base_test import BaseTest | ||
|
||
|
||
@pytest.fixture
def sample_df():
    """Provide a one-column DataFrame of six course names.

    The names cover statistics, cooking and hardware topics.
    """
    course_names = [
        "Probability and Random Processes",
        "Statistics and Data Analysis",
        "Cooking Basics",
        "Advanced Culinary Arts",
        "Digital Circuit Design",
        "Computer Architecture",
    ]
    return pd.DataFrame({"Course Name": course_names})
|
||
|
||
class TestClusterBy(BaseTest):
    """Tests for the ``sem_cluster_by`` DataFrame operator."""

    def test_basic_clustering(self, sample_df):
        """Two clusters should separate the tech courses from the culinary ones."""
        clustered = sample_df.sem_cluster_by("Course Name", 2)
        assert "cluster_id" in clustered.columns
        assert len(clustered["cluster_id"].unique()) == 2
        assert len(clustered) == len(sample_df)

        # Course names that landed in each of the two clusters.
        first_cluster = set(clustered[clustered["cluster_id"] == 0]["Course Name"])
        second_cluster = set(clustered[clustered["cluster_id"] == 1]["Course Name"])

        # Expected groupings.
        tech_courses = {
            "Probability and Random Processes",
            "Statistics and Data Analysis",
            "Digital Circuit Design",
            "Computer Architecture",
        }
        culinary_courses = {
            "Cooking Basics",
            "Advanced Culinary Arts",
        }

        # Cluster ids are arbitrary, so accept either assignment.
        tech_is_first = first_cluster == tech_courses and second_cluster == culinary_courses
        tech_is_second = second_cluster == tech_courses and first_cluster == culinary_courses
        assert tech_is_first or tech_is_second, "Clusters don't match expected course groupings"

    def test_clustering_with_more_clusters(self, sample_df):
        """Asking for three clusters should yield three distinct ids."""
        clustered = sample_df.sem_cluster_by("Course Name", 3)
        assert len(clustered["cluster_id"].unique()) == 3
        assert len(clustered) == len(sample_df)

    def test_clustering_with_single_cluster(self, sample_df):
        """With one cluster every row shares a single id, and the first row's id is 0."""
        clustered = sample_df.sem_cluster_by("Course Name", 1)
        cluster_ids = clustered["cluster_id"]
        assert len(cluster_ids.unique()) == 1
        assert cluster_ids.iloc[0] == 0

    def test_clustering_with_invalid_column(self, sample_df):
        """Clustering on a column that does not exist should raise ValueError."""
        with pytest.raises(ValueError, match="Column .* not found in DataFrame"):
            sample_df.sem_cluster_by("NonExistentColumn", 2)

    def test_clustering_with_empty_dataframe(self):
        """An empty frame should come back empty but with a cluster_id column."""
        no_rows = pd.DataFrame(columns=["Course Name"])
        clustered = no_rows.sem_cluster_by("Course Name", 2)
        assert len(clustered) == 0
        assert "cluster_id" in clustered.columns

    def test_clustering_similar_items(self, sample_df):
        """Closely related course names should share a cluster."""
        clustered = sample_df.sem_cluster_by("Course Name", 3)

        def cluster_of(keyword):
            # Cluster id of the first course whose name contains `keyword`.
            matches = clustered[clustered["Course Name"].str.contains(keyword)]
            return matches["cluster_id"].iloc[0]

        # The two statistics-flavoured courses belong together.
        stats_cluster = cluster_of("Statistics")
        prob_cluster = cluster_of("Probability")
        assert stats_cluster == prob_cluster

        # So do the two cooking-flavoured courses.
        cooking_cluster = cluster_of("Cooking")
        culinary_cluster = cluster_of("Culinary")
        assert cooking_cluster == culinary_cluster

    def test_clustering_with_verbose(self, sample_df):
        """Passing verbose=True should still yield a valid two-way clustering."""
        clustered = sample_df.sem_cluster_by("Course Name", 2, verbose=True)
        assert "cluster_id" in clustered.columns
        assert len(clustered["cluster_id"].unique()) == 2

    def test_clustering_with_iterations(self, sample_df):
        """Different niter values should both yield two clusters."""
        few_iterations = sample_df.sem_cluster_by("Course Name", 2, niter=5)
        many_iterations = sample_df.sem_cluster_by("Course Name", 2, niter=20)

        # Either iteration budget must produce a valid clustering.
        assert len(few_iterations["cluster_id"].unique()) == 2
        assert len(many_iterations["cluster_id"].unique()) == 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from tests.base_test import BaseTest | ||
|
||
|
||
@pytest.fixture
def sample_df():
    """Provide eight courses with name, department and level columns."""
    course_names = [
        "Introduction to Programming",
        "Advanced Programming",
        "Cooking Basics",
        "Advanced Culinary Arts",
        "Data Structures",
        "Algorithms",
        "French Cuisine",
        "Italian Cooking",
    ]
    departments = [
        "CS", "CS", "Culinary", "Culinary",
        "CS", "CS", "Culinary", "Culinary",
    ]
    levels = [
        100, 200, 100, 200,
        300, 300, 200, 200,
    ]
    return pd.DataFrame(
        {
            "Course Name": course_names,
            "Department": departments,
            "Level": levels,
        }
    )
|
||
|
||
class TestSearch(BaseTest):
    """Tests for ``sem_search``, alone and combined with filters."""

    def test_basic_search(self, sample_df):
        """A plain search should surface both programming courses."""
        indexed = sample_df.sem_index("Course Name", "course_index")
        hits = indexed.sem_search("Course Name", "programming courses", K=2)
        hit_names = hits["Course Name"].tolist()
        assert len(hits) == 2
        assert "Introduction to Programming" in hit_names
        assert "Advanced Programming" in hit_names

    def test_filtered_search_relational(self, sample_df):
        """Searching after a relational filter should only return rows that passed it."""
        # Build the index first, then narrow to the CS department.
        indexed = sample_df.sem_index("Course Name", "course_index")
        cs_only = indexed[indexed["Department"] == "CS"]
        hits = cs_only.sem_search("Course Name", "advanced courses", K=2)

        assert len(hits) == 2
        # Only CS courses may come back.
        assert (hits["Department"] == "CS").all()
        assert "Advanced Programming" in hits["Course Name"].tolist()

    def test_filtered_search_semantic(self, sample_df):
        """Searching after a semantic filter should only return rows that passed it."""
        # Build the index first, then keep the cooking-related courses.
        indexed = sample_df.sem_index("Course Name", "course_index")
        cooking_only = indexed.sem_filter("{Course Name} is related to cooking")
        hits = cooking_only.sem_search("Course Name", "advanced level courses", K=2)

        assert len(hits) == 2
        # Only cooking-related courses may come back.
        assert (hits["Department"] == "Culinary").all()
        assert "Advanced Culinary Arts" in hits["Course Name"].tolist()

    def test_filtered_search_combined(self, sample_df):
        """Relational and semantic filters should both hold on the search results."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # Relational filter first, semantic filter second.
        upper_level = indexed[indexed["Level"] >= 200]
        upper_level_cs = upper_level.sem_filter("{Course Name} is related to computer science")
        hits = upper_level_cs.sem_search("Course Name", "data structures and algorithms", K=2)

        assert len(hits) == 2
        # Only advanced CS courses may come back.
        assert (hits["Department"] == "CS").all()
        assert (hits["Level"] >= 200).all()
        hit_names = hits["Course Name"].tolist()
        assert "Data Structures" in hit_names
        assert "Algorithms" in hit_names

    def test_filtered_search_empty_result(self, sample_df):
        """Searching a frame that a filter emptied should return nothing."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # No course has a level above 1000, so this selects zero rows.
        nothing_left = indexed[indexed["Level"] > 1000]
        hits = nothing_left.sem_search("Course Name", "any course", K=2)

        assert len(hits) == 0

    def test_filtered_search_with_scores(self, sample_df):
        """return_scores=True should add a similarity score column within [0, 1]."""
        indexed = sample_df.sem_index("Course Name", "course_index")
        cs_only = indexed[indexed["Department"] == "CS"]

        hits = cs_only.sem_search("Course Name", "programming courses", K=2, return_scores=True)

        assert "vec_scores_sim_score" in hits.columns
        scores = hits["vec_scores_sim_score"]
        assert len(scores) == 2
        # Every score must fall inside the closed interval [0, 1].
        assert scores.between(0, 1).all()