From 2f0e09ee505ba55164b33d62fb317d3d8356052a Mon Sep 17 00:00:00 2001 From: Sid Jha <45739834+sidjha1@users.noreply.github.com> Date: Tue, 12 Nov 2024 22:03:20 -0800 Subject: [PATCH] Support Groups for Sem Agg (#32) See the test for the behavior --- .github/tests/lm_tests.py | 24 ++++++++++++++++++++++++ lotus/sem_ops/sem_agg.py | 11 ++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py index e8e46f47..3152c20a 100644 --- a/.github/tests/lm_tests.py +++ b/.github/tests/lm_tests.py @@ -136,11 +136,35 @@ def test_agg_then_map(setup_models, model): df = pd.DataFrame(data) agg_instruction = "What is the most common name in {Text}?" agg_df = df.sem_agg(agg_instruction, suffix="draft_output") + assert len(agg_df) == 1 + map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word" cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john" +@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini")) +def test_group_by_with_agg(setup_models, model): + lm = setup_models[model] + lotus.settings.configure(lm=lm) + + data = { + "Names": ["Michael", "Anakin", "Luke", "Dwight"], + "Show": ["The Office", "Star Wars", "Star Wars", "The Office"], + } + df = pd.DataFrame(data) + agg_instruction = "Summarize {Names}" + agg_df = df.sem_agg(agg_instruction, suffix="draft_output", group_by=["Show"]) + assert len(agg_df) == 2 + + # Map post-processing + map_instruction = "{draft_output} is a draft answer to the question 'Summarize the names'. Clean up the draft answer is just a comma separated list of names." + cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output") + + assert set(cleaned_df["final_output"].values[0].lower().strip(".,!?\"'").split(", ")) == {"anakin", "luke"} + assert set(cleaned_df["final_output"].values[1].lower().strip(".,!?\"'").split(", ")) == {"michael", "dwight"} + + ################################################################################ # Cascade tests ################################################################################ diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py index 6fd9e8b0..56a95ff9 100644 --- a/lotus/sem_ops/sem_agg.py +++ b/lotus/sem_ops/sem_agg.py @@ -138,6 +138,7 @@ def __call__( user_instruction: str, all_cols: bool = False, suffix: str = "_output", + group_by: list[str] | None = None, ) -> pd.DataFrame: """ Applies semantic aggregation over a dataframe. @@ -146,7 +147,7 @@ def __call__( user_instruction (str): The user instruction for aggregation. all_cols (bool): Whether to use all columns in the dataframe. Defaults to False. suffix (str): The suffix for the new column. Defaults to "_output". - + group_by (list[str] | None): The columns to group by before aggregation. Each group will be aggregated separately. Returns: pd.DataFrame: The dataframe with the aggregated answer. """ @@ -163,6 +164,14 @@ def __call__( if column not in self._obj.columns: raise ValueError(f"column {column} not found in DataFrame. Given usr instruction: {user_instruction}") + if group_by: + grouped = self._obj.groupby(group_by) + new_df = pd.DataFrame() + for name, group in grouped: + res = group.sem_agg(user_instruction, all_cols, suffix, None) + new_df = pd.concat([new_df, res]) + return new_df + # Sort df by partition_id if it exists if "_lotus_partition_id" in self._obj.columns: self._obj = self._obj.sort_values(by="_lotus_partition_id")