Skip to content

Commit

Permalink
initial scaffolding for adding vector store / vector database integra…
Browse files Browse the repository at this point in the history
…tion (#76)
  • Loading branch information
AmoghTantradi authored Jan 13, 2025
1 parent 8a207aa commit 4d4ca82
Show file tree
Hide file tree
Showing 11 changed files with 324 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lotus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import lotus.dtype_extensions
import lotus.models
import lotus.vector_store
import lotus.nl_expression
import lotus.templates
import lotus.utils
Expand Down Expand Up @@ -44,6 +45,7 @@
"templates",
"logger",
"models",
"vector_store",
"utils",
"dtype_extensions",
]
2 changes: 2 additions & 0 deletions lotus/settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import lotus.models
import lotus.vector_store
from lotus.types import SerializationFormat

# NOTE: Settings class is not thread-safe
Expand All @@ -10,6 +11,7 @@ class Settings:
rm: lotus.models.RM | None = None
helper_lm: lotus.models.LM | None = None
reranker: lotus.models.Reranker | None = None
vs: lotus.vector_store.VS | None = None

# Cache settings
enable_cache: bool = False
Expand Down
7 changes: 7 additions & 0 deletions lotus/vector_store/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from lotus.vector_store.vs import VS
from lotus.vector_store.weaviate_vs import WeaviateVS
from lotus.vector_store.pinecone_vs import PineconeVS
from lotus.vector_store.chroma_vs import ChromaVS
from lotus.vector_store.qdrant_vs import QdrantVS

# Names re-exported as the public API of this package: the `VS` base class
# and the backend-specific classes imported above.
__all__ = ["VS", "WeaviateVS", "PineconeVS", "ChromaVS", "QdrantVS"]
16 changes: 16 additions & 0 deletions lotus/vector_store/chroma_vs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lotus.vector_store.vs import VS


class ChromaVS(VS):
    """Vector store backed by ChromaDB (scaffolding only).

    NOTE(review): only ``__init__`` is defined here. If ``VS`` declares
    abstract methods, instantiating this class raises ``TypeError`` until
    they are implemented.
    """

    def __init__(self) -> None:
        """Check that the optional ``chromadb`` dependency can be imported.

        Raises:
            ImportError: If ``chromadb`` is not installed. The original
                import failure is attached as the cause so that a broken
                transitive dependency is not hidden.
        """
        try:
            # Imported only to verify availability; the module is not used yet.
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The chromadb library is required to use ChromaVS. Install it with `pip install chromadb`",
            ) from err

        super().__init__()
17 changes: 17 additions & 0 deletions lotus/vector_store/pinecone_vs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from lotus.vector_store.vs import VS


class PineconeVS(VS):
    """Vector store backed by Pinecone (scaffolding only).

    NOTE(review): only ``__init__`` is defined here. If ``VS`` declares
    abstract methods, instantiating this class raises ``TypeError`` until
    they are implemented.
    """

    def __init__(self) -> None:
        """Check that the optional ``pinecone`` dependency can be imported.

        Raises:
            ImportError: If ``pinecone`` is not installed. The original
                import failure is attached as the cause so that a broken
                transitive dependency is not hidden.
        """
        try:
            # Imported only to verify availability; the module is not used yet.
            import pinecone  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The pinecone library is required to use PineconeVS. Install it with `pip install pinecone`",
            ) from err

        super().__init__()
16 changes: 16 additions & 0 deletions lotus/vector_store/qdrant_vs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lotus.vector_store.vs import VS


class QdrantVS(VS):
    """Vector store backed by Qdrant (scaffolding only).

    NOTE(review): only ``__init__`` is defined here. If ``VS`` declares
    abstract methods, instantiating this class raises ``TypeError`` until
    they are implemented.
    """

    def __init__(self) -> None:
        """Check that the optional ``qdrant_client`` dependency can be imported.

        Raises:
            ImportError: If ``qdrant_client`` is not installed. The original
                import failure is attached as the cause so that a broken
                transitive dependency is not hidden.
        """
        try:
            # Imported only to verify availability; the module is not used yet.
            import qdrant_client  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "The qdrant library is required to use QdrantVS. Install it with `pip install qdrant_client`",
            ) from err

        super().__init__()
32 changes: 32 additions & 0 deletions lotus/vector_store/vs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from PIL import Image

from lotus.types import RMOutput


class VS(ABC):
    """Abstract class for vector stores.

    Subclasses must implement ``index``, ``search`` and
    ``get_vectors_from_index`` before they can be instantiated.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def index(self, docs: pd.Series, index_dir):
        """Build an index over ``docs``.

        Args:
            docs: Series holding the items to index.
            index_dir: Identifier of the index to create.
                NOTE(review): the expected type (path string vs. collection
                name) is not fixed by this interface -- confirm with the
                concrete backends.
        """
        pass

    @abstractmethod
    def search(
        self,
        queries: pd.Series | str | Image.Image | list | NDArray[np.float64],
        K: int,
        **kwargs: Any,
    ) -> RMOutput:
        """Search the store for ``queries``.

        Args:
            queries: One query or a collection of queries, given as a Series,
                a string, an image, a list, or an array of floats.
            K: Presumably the number of results to return per query --
                TODO confirm against the implementations.
            **kwargs: Backend-specific options. Annotated as ``Any`` because
                a ``**kwargs`` annotation describes each individual value,
                not the mapping as a whole.

        Returns:
            An ``RMOutput`` describing the search results.
        """
        pass

    @abstractmethod
    def get_vectors_from_index(self, collection_name: str, ids: list[int]) -> NDArray[np.float64]:
        """Fetch stored vectors by id.

        Args:
            collection_name: Name of the collection to read from.
            ids: Integer ids of the vectors to fetch.

        Returns:
            A float64 array containing the requested vectors.
        """
        pass
13 changes: 13 additions & 0 deletions lotus/vector_store/weaviate_vs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from lotus.vector_store.vs import VS


class WeaviateVS(VS):
    """Vector store backed by Weaviate (scaffolding only).

    NOTE(review): only ``__init__`` is defined here. If ``VS`` declares
    abstract methods, instantiating this class raises ``TypeError`` until
    they are implemented.
    """

    def __init__(self) -> None:
        """Check that the optional ``weaviate`` dependency can be imported.

        Raises:
            ImportError: If ``weaviate`` is not installed. The original
                import failure is attached as the cause so that a broken
                transitive dependency is not hidden.
        """
        try:
            # Imported only to verify availability; the module is not used yet.
            import weaviate  # noqa: F401
        except ImportError as err:
            # Message follows the same "required ... install with" pattern as
            # the other vector store backends and names the pip package.
            raise ImportError(
                "The weaviate library is required to use WeaviateVS. Install it with `pip install weaviate-client`",
            ) from err

        super().__init__()
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ numpy==1.26.4
pandas==2.2.2
sentence-transformers==3.0.1
tiktoken==0.7.0
tqdm==4.66.4
tqdm==4.66.4
weaviate-client==4.10.2
pinecone==5.4.2
chromadb==0.6.2
qdrant-client==1.12.2
104 changes: 104 additions & 0 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pandas as pd
import pytest

from tests.base_test import BaseTest


@pytest.fixture
def sample_df():
    """Provide a one-column DataFrame of six course names for clustering tests."""
    course_names = [
        "Probability and Random Processes",
        "Statistics and Data Analysis",
        "Cooking Basics",
        "Advanced Culinary Arts",
        "Digital Circuit Design",
        "Computer Architecture",
    ]
    frame = pd.DataFrame({"Course Name": course_names})
    return frame


class TestClusterBy(BaseTest):
    """Tests for the ``sem_cluster_by`` DataFrame operator on course names."""

    def test_basic_clustering(self, sample_df):
        """Two clusters should separate technical courses from culinary ones."""
        clustered = sample_df.sem_cluster_by("Course Name", 2)
        assert "cluster_id" in clustered.columns
        distinct_ids = clustered["cluster_id"].unique()
        assert len(distinct_ids) == 2
        assert len(clustered) == len(sample_df)

        # Course names assigned to each of the two cluster ids.
        first_group = set(clustered[clustered["cluster_id"] == 0]["Course Name"])
        second_group = set(clustered[clustered["cluster_id"] == 1]["Course Name"])

        # Groupings we expect, regardless of which id each one receives.
        expected_tech = {
            "Probability and Random Processes",
            "Statistics and Data Analysis",
            "Digital Circuit Design",
            "Computer Architecture",
        }
        expected_culinary = {
            "Cooking Basics",
            "Advanced Culinary Arts",
        }

        # Cluster ids are arbitrary, so accept either assignment.
        tech_is_first = first_group == expected_tech and second_group == expected_culinary
        tech_is_second = second_group == expected_tech and first_group == expected_culinary
        assert tech_is_first or tech_is_second, "Clusters don't match expected course groupings"

    def test_clustering_with_more_clusters(self, sample_df):
        """Requesting three clusters yields three distinct ids and keeps every row."""
        clustered = sample_df.sem_cluster_by("Course Name", 3)
        distinct_ids = clustered["cluster_id"].unique()
        assert len(distinct_ids) == 3
        assert len(clustered) == len(sample_df)

    def test_clustering_with_single_cluster(self, sample_df):
        """A single requested cluster assigns id 0."""
        clustered = sample_df.sem_cluster_by("Course Name", 1)
        distinct_ids = clustered["cluster_id"].unique()
        assert len(distinct_ids) == 1
        first_id = clustered["cluster_id"].iloc[0]
        assert first_id == 0

    def test_clustering_with_invalid_column(self, sample_df):
        """Clustering on a column that does not exist raises ``ValueError``."""
        expected_message = "Column .* not found in DataFrame"
        with pytest.raises(ValueError, match=expected_message):
            sample_df.sem_cluster_by("NonExistentColumn", 2)

    def test_clustering_with_empty_dataframe(self):
        """An empty frame clusters to an empty frame that still has ``cluster_id``."""
        no_rows = pd.DataFrame(columns=["Course Name"])
        clustered = no_rows.sem_cluster_by("Course Name", 2)
        assert len(clustered) == 0
        assert "cluster_id" in clustered.columns

    def test_clustering_similar_items(self, sample_df):
        """Related courses share a cluster when three clusters are requested."""
        clustered = sample_df.sem_cluster_by("Course Name", 3)
        names = clustered["Course Name"]

        # Statistics and probability should land together.
        stats_id = clustered[names.str.contains("Statistics")]["cluster_id"].iloc[0]
        prob_id = clustered[names.str.contains("Probability")]["cluster_id"].iloc[0]
        assert stats_id == prob_id

        # The two cooking-related courses should land together as well.
        cooking_id = clustered[names.str.contains("Cooking")]["cluster_id"].iloc[0]
        culinary_id = clustered[names.str.contains("Culinary")]["cluster_id"].iloc[0]
        assert cooking_id == culinary_id

    def test_clustering_with_verbose(self, sample_df):
        """Passing ``verbose=True`` still produces a valid two-way clustering."""
        clustered = sample_df.sem_cluster_by("Course Name", 2, verbose=True)
        assert "cluster_id" in clustered.columns
        distinct_ids = clustered["cluster_id"].unique()
        assert len(distinct_ids) == 2

    def test_clustering_with_iterations(self, sample_df):
        """Different ``niter`` values each produce a valid two-way clustering."""
        few_iterations = sample_df.sem_cluster_by("Course Name", 2, niter=5)
        many_iterations = sample_df.sem_cluster_by("Course Name", 2, niter=20)

        # Only validity is checked; the two runs are not compared to each other.
        assert len(few_iterations["cluster_id"].unique()) == 2
        assert len(many_iterations["cluster_id"].unique()) == 2
110 changes: 110 additions & 0 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pandas as pd
import pytest

from tests.base_test import BaseTest


@pytest.fixture
def sample_df():
    """Provide eight courses with their department and numeric level."""
    course_names = [
        "Introduction to Programming",
        "Advanced Programming",
        "Cooking Basics",
        "Advanced Culinary Arts",
        "Data Structures",
        "Algorithms",
        "French Cuisine",
        "Italian Cooking",
    ]
    departments = [
        "CS",
        "CS",
        "Culinary",
        "Culinary",
        "CS",
        "CS",
        "Culinary",
        "Culinary",
    ]
    levels = [100, 200, 100, 200, 300, 300, 200, 200]

    columns = {
        "Course Name": course_names,
        "Department": departments,
        "Level": levels,
    }
    return pd.DataFrame(columns)


class TestSearch(BaseTest):
    """Tests for ``sem_search`` on an indexed frame, with and without filters."""

    def test_basic_search(self, sample_df):
        """An unfiltered search returns the two programming courses."""
        indexed = sample_df.sem_index("Course Name", "course_index")
        hits = indexed.sem_search("Course Name", "programming courses", K=2)
        assert len(hits) == 2
        found_names = hits["Course Name"].values
        assert "Introduction to Programming" in found_names
        assert "Advanced Programming" in found_names

    def test_filtered_search_relational(self, sample_df):
        """Searching after a relational filter only returns rows that passed it."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # Restrict to the CS department, then search within that subset.
        cs_only = indexed[indexed["Department"] == "CS"]
        hits = cs_only.sem_search("Course Name", "advanced courses", K=2)

        assert len(hits) == 2
        # Every hit must come from the CS subset.
        assert all(department == "CS" for department in hits["Department"])
        assert "Advanced Programming" in hits["Course Name"].values

    def test_filtered_search_semantic(self, sample_df):
        """Searching after a semantic filter only returns rows that passed it."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # Keep cooking-related rows via the semantic filter, then search them.
        cooking_only = indexed.sem_filter("{Course Name} is related to cooking")
        hits = cooking_only.sem_search("Course Name", "advanced level courses", K=2)

        assert len(hits) == 2
        # Every hit must be a culinary course.
        assert all(department == "Culinary" for department in hits["Department"])
        assert "Advanced Culinary Arts" in hits["Course Name"].values

    def test_filtered_search_combined(self, sample_df):
        """Relational and semantic filters can be chained before searching."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # Relational step: level 200 and above.
        upper_level = indexed[indexed["Level"] >= 200]
        # Semantic step: computer-science related courses.
        upper_level_cs = upper_level.sem_filter("{Course Name} is related to computer science")
        hits = upper_level_cs.sem_search("Course Name", "data structures and algorithms", K=2)

        assert len(hits) == 2
        # Hits must satisfy both filters.
        assert all(department == "CS" for department in hits["Department"])
        assert all(course_level >= 200 for course_level in hits["Level"])
        found_names = hits["Course Name"].values
        assert "Data Structures" in found_names
        assert "Algorithms" in found_names

    def test_filtered_search_empty_result(self, sample_df):
        """Searching a subset with no rows yields no hits."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        # No course has a level above 1000, so the subset is empty.
        nothing_left = indexed[indexed["Level"] > 1000]
        hits = nothing_left.sem_search("Course Name", "any course", K=2)

        assert len(hits) == 0

    def test_filtered_search_with_scores(self, sample_df):
        """``return_scores=True`` adds a similarity-score column to the hits."""
        indexed = sample_df.sem_index("Course Name", "course_index")

        cs_only = indexed[indexed["Department"] == "CS"]
        hits = cs_only.sem_search(
            "Course Name",
            "programming courses",
            K=2,
            return_scores=True,
        )

        assert "vec_scores_sim_score" in hits.columns
        similarity = hits["vec_scores_sim_score"]
        assert len(similarity) == 2
        # Each score is expected to fall within [0, 1].
        assert all(0 <= value <= 1 for value in similarity)

0 comments on commit 4d4ca82

Please sign in to comment.