Skip to content

Commit

Permalink
Small updates
Browse files: browse the repository at this point in the history
  • Loading branch information
sidjha1 committed Dec 10, 2024
1 parent 833511f commit ca19116
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import lotus
from lotus.models import LM, SentenceTransformersRM
from lotus.types import SemJoinCascadeArgs
from lotus.types import CascadeArgs

################################################################################
# Setup
Expand Down Expand Up @@ -286,10 +286,12 @@ def test_filter_cascade(setup_models):
def test_join_cascade(setup_models):
models = setup_models
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
lotus.settings.configure(lm=models["gpt-4o-mini"],
rm=rm,
min_join_cascade_size=10, # for smaller testings
cascade_IS_random_seed=42)
lotus.settings.configure(
lm=models["gpt-4o-mini"],
rm=rm,
min_join_cascade_size=10, # for smaller tests
cascade_IS_random_seed=42,
)

data1 = {
"School": [
Expand All @@ -308,37 +310,41 @@ def test_join_cascade(setup_models):
"Yale University",
"Cornell University",
"University of Pennsylvania",
]}
]
}
data2 = {"School Type": ["Public School", "Private School"]}

df1 = pd.DataFrame(data1)
df2 = pd.DataFrame(data2)
join_instruction = "{School} is a {School Type}"
expected_pairs = [("University of California, Berkeley", "Public School"), ("Stanford University", "Private School")]
expected_pairs = [
("University of California, Berkeley", "Public School"),
("Stanford University", "Private School"),
]

# Cascade join
joined_df, stats = df1.sem_join(
df2, join_instruction,
cascade_args=SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7),
return_stats=True)
df2, join_instruction, cascade_args=CascadeArgs(recall_target=0.7, precision_target=0.7), return_stats=True
)

for pair in expected_pairs:
school, school_type = pair
exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
assert stats["join_resolved_by_helper_model"] > 0, stats

# All joins resolved by the large model
joined_df, stats = df1.sem_join(
df2, join_instruction,
cascade_args=SemJoinCascadeArgs(recall_target=1.0, precision_target=1.0),
return_stats=True)
df2, join_instruction, cascade_args=CascadeArgs(recall_target=1.0, precision_target=1.0), return_stats=True
)

for pair in expected_pairs:
school, school_type = pair
exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any()
exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any()
assert exists, f"Expected pair {pair} does not exist in the dataframe!"
assert stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"], stats # helper negative still can still meet the precision target
assert (
stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"]
), stats # helper negatives can still meet the precision target
assert stats["join_helper_positive"] == 0, stats


Expand Down

0 comments on commit ca19116

Please sign in to comment.