From ca191167603e92307888e6bf1d369f71d4c406de Mon Sep 17 00:00:00 2001 From: Sid Jha Date: Mon, 9 Dec 2024 22:07:24 -0800 Subject: [PATCH] Small updates --- .github/tests/lm_tests.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index 2c772b86..c0f6c531 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -6,7 +6,7 @@ import lotus from lotus.models import LM, SentenceTransformersRM -from lotus.types import SemJoinCascadeArgs +from lotus.types import CascadeArgs ################################################################################ # Setup @@ -286,10 +286,12 @@ def test_filter_cascade(setup_models): def test_join_cascade(setup_models): models = setup_models rm = SentenceTransformersRM(model="intfloat/e5-base-v2") - lotus.settings.configure(lm=models["gpt-4o-mini"], - rm=rm, - min_join_cascade_size=10, # for smaller testings - cascade_IS_random_seed=42) + lotus.settings.configure( + lm=models["gpt-4o-mini"], + rm=rm, + min_join_cascade_size=10, # for smaller testings + cascade_IS_random_seed=42, + ) data1 = { "School": [ @@ -308,37 +310,41 @@ def test_join_cascade(setup_models): "Yale University", "Cornell University", "University of Pennsylvania", - ]} + ] + } data2 = {"School Type": ["Public School", "Private School"]} df1 = pd.DataFrame(data1) df2 = pd.DataFrame(data2) join_instruction = "{School} is a {School Type}" - expected_pairs = [("University of California, Berkeley", "Public School"), ("Stanford University", "Private School")] + expected_pairs = [ + ("University of California, Berkeley", "Public School"), + ("Stanford University", "Private School"), + ] # Cascade join joined_df, stats = df1.sem_join( - df2, join_instruction, - cascade_args=SemJoinCascadeArgs(recall_target=0.7, precision_target=0.7), - return_stats=True) + df2, join_instruction, cascade_args=CascadeArgs(recall_target=0.7, precision_target=0.7), 
return_stats=True + ) for pair in expected_pairs: school, school_type = pair - exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any() + exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any() assert exists, f"Expected pair {pair} does not exist in the dataframe!" assert stats["join_resolved_by_helper_model"] > 0, stats # All joins resolved by the large model joined_df, stats = df1.sem_join( - df2, join_instruction, - cascade_args=SemJoinCascadeArgs(recall_target=1.0, precision_target=1.0), - return_stats=True) + df2, join_instruction, cascade_args=CascadeArgs(recall_target=1.0, precision_target=1.0), return_stats=True + ) for pair in expected_pairs: school, school_type = pair - exists = ((joined_df['School'] == school) & (joined_df['School Type'] == school_type)).any() + exists = ((joined_df["School"] == school) & (joined_df["School Type"] == school_type)).any() assert exists, f"Expected pair {pair} does not exist in the dataframe!" - assert stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"], stats # helper negative still can still meet the precision target + assert ( + stats["join_resolved_by_large_model"] > stats["join_resolved_by_helper_model"] + ), stats # helper negatives can still meet the precision target assert stats["join_helper_positive"] == 0, stats