Skip to content

Commit

Permalink
Add RM Tests (#13)
Browse files Browse the repository at this point in the history
Adds RM tests to the CI. Tests the following
1. Cluster by
2. Search (top-K only)
3. Search (reranker only)
4. Search (both top-K + reranker)
5. Dedup
6. Semantic Join
  • Loading branch information
sidjha1 authored Oct 1, 2024
1 parent cb35763 commit 6057141
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 2 deletions.
138 changes: 138 additions & 0 deletions .github/tests/rm_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pandas as pd
import pytest

import lotus
from lotus.models import CrossEncoderModel, E5Model

# Set logger level to DEBUG
lotus.logger.setLevel("DEBUG")


@pytest.fixture
def setup_models():
# Set up embedder and reranker model
rm = E5Model(model="intfloat/e5-small-v2")
reranker = CrossEncoderModel(model="mixedbread-ai/mxbai-rerank-xsmall-v1")
return rm, reranker


def test_cluster_by(setup_models):
rm, _ = setup_models
lotus.settings.configure(rm=rm)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_cluster_by("Course Name", 2)
groups = df.groupby("cluster_id")["Course Name"].apply(set).to_dict()
assert len(groups) == 2, groups
if "Cooking" in groups[0]:
cooking_group = groups[0]
probability_group = groups[1]
else:
cooking_group = groups[1]
probability_group = groups[0]

assert cooking_group == {"Cooking", "Food Sciences"}, groups
assert probability_group == {"Probability and Random Processes", "Optimization Methods in Engineering"}, groups


def test_search_rm_only(setup_models):
rm, _ = setup_models
lotus.settings.configure(rm=rm)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_search("Course Name", "Optimization", K=1)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]


def test_search_reranker_only(setup_models):
_, reranker = setup_models
lotus.settings.configure(reranker=reranker)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_search("Course Name", "Optimization", n_rerank=2)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering", "Probability and Random Processes"]


def test_search(setup_models):
rm, reranker = setup_models
lotus.settings.configure(rm=rm, reranker=reranker)

data = {
"Course Name": [
"Probability and Random Processes",
"Cooking",
"Food Sciences",
"Optimization Methods in Engineering",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Course Name", "index_dir")
df = df.sem_search("Course Name", "Optimization", K=2, n_rerank=1)
assert df["Course Name"].tolist() == ["Optimization Methods in Engineering"]


def test_dedup(setup_models):
rm, _ = setup_models
lotus.settings.configure(rm=rm)
data = {
"Text": [
"Probability and Random Processes",
"Probability and Markov Chains",
"Harry Potter",
"Harry James Potter",
]
}
df = pd.DataFrame(data)
df = df.sem_index("Text", "index_dir").sem_dedup("Text", threshold=0.85)
kept = df["Text"].tolist()
kept.sort()
assert len(kept) == 2, kept
assert "Harry" in kept[0], kept
assert "Probability" in kept[1], kept


def test_sim_join(setup_models):
rm, _ = setup_models
lotus.settings.configure(rm=rm)

data1 = {
"Course Name": [
"History of the Atlantic World",
"Riemannian Geometry",
]
}

data2 = {"Skill": ["Math", "History"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2).sem_index("Skill", "index_dir")
joined_df = df1.sem_sim_join(df2, left_on="Course Name", right_on="Skill", K=1)
joined_pairs = set(zip(joined_df["Course Name"], joined_df["Skill"]))
expected_pairs = {("History of the Atlantic World", "History"), ("Riemannian Geometry", "Math")}
assert joined_pairs == expected_pairs, joined_pairs
6 changes: 4 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: ruff check .

test:
name: LM Tests
name: Python Tests
runs-on: ubuntu-latest
timeout-minutes: 5

Expand All @@ -57,4 +57,6 @@ jobs:
- name: Run Python tests
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: pytest .github/tests/lm_tests.py
run: |
pytest .github/tests/lm_tests.py
pytest .github/tests/rm_tests.py

0 comments on commit 6057141

Please sign in to comment.