Skip to content

Commit

Permalink
Support Groups for Sem Agg (#32)
Browse files Browse the repository at this point in the history
See the test for the behavior
  • Loading branch information
sidjha1 authored Nov 13, 2024
1 parent 30ff6f7 commit 2f0e09e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
24 changes: 24 additions & 0 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,35 @@ def test_agg_then_map(setup_models, model):
df = pd.DataFrame(data)
agg_instruction = "What is the most common name in {Text}?"
agg_df = df.sem_agg(agg_instruction, suffix="draft_output")
assert len(agg_df) == 1

map_instruction = "{draft_output} is a draft answer to the question 'What is the most common name?'. Clean up the draft answer so that there is just a single name. Your answer MUST be on word"
cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output")
assert cleaned_df["final_output"].values[0].lower().strip(".,!?\"'") == "john"


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_group_by_with_agg(setup_models, model):
    """Check that ``sem_agg`` with ``group_by`` aggregates each group separately.

    The frame holds two shows with two names each, so grouping by ``Show``
    must yield exactly two aggregated rows. A follow-up ``sem_map`` turns each
    draft summary into a comma separated list of names, which is then compared
    as a set so that the order of names within a group does not matter.

    Args:
        setup_models: Fixture mapping model names to configured LM objects.
        model: Name of the model under test, supplied by ``parametrize``.
    """
    lm = setup_models[model]
    lotus.settings.configure(lm=lm)

    data = {
        "Names": ["Michael", "Anakin", "Luke", "Dwight"],
        "Show": ["The Office", "Star Wars", "Star Wars", "The Office"],
    }
    df = pd.DataFrame(data)
    agg_instruction = "Summarize {Names}"
    agg_df = df.sem_agg(agg_instruction, suffix="draft_output", group_by=["Show"])
    assert len(agg_df) == 2

    # Map post-processing
    # The wording mirrors the sibling agg-then-map test; the previous prompt
    # was missing "so that it", which left the instruction ungrammatical.
    map_instruction = (
        "{draft_output} is a draft answer to the question 'Summarize the names'. "
        "Clean up the draft answer so that it is just a comma separated list of names."
    )
    cleaned_df = agg_df.sem_map(map_instruction, suffix="final_output")

    def parse_names(answer):
        """Return the lower-cased set of names in a comma separated answer."""
        # Split on the bare comma and strip each piece, so spacing differences
        # ("a,b" vs "a, b") and stray punctuation or newlines around any name
        # do not cause spurious failures.
        pieces = (piece.strip(" \t\n.,!?\"'") for piece in answer.lower().split(","))
        return {piece for piece in pieces if piece}

    final_outputs = cleaned_df["final_output"].values
    # NOTE(review): positional checks assume rows arrive in sorted group-key
    # order ("Star Wars" before "The Office"), which is the pandas groupby
    # default - confirm sem_map preserves the row order.
    assert parse_names(final_outputs[0]) == {"anakin", "luke"}
    assert parse_names(final_outputs[1]) == {"michael", "dwight"}


################################################################################
# Cascade tests
################################################################################
Expand Down
11 changes: 10 additions & 1 deletion lotus/sem_ops/sem_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __call__(
user_instruction: str,
all_cols: bool = False,
suffix: str = "_output",
group_by: list[str] | None = None,
) -> pd.DataFrame:
"""
Applies semantic aggregation over a dataframe.
Expand All @@ -146,7 +147,7 @@ def __call__(
user_instruction (str): The user instruction for aggregation.
all_cols (bool): Whether to use all columns in the dataframe. Defaults to False.
suffix (str): The suffix for the new column. Defaults to "_output".
group_by (list[str] | None): The columns to group by before aggregation. Each group will be aggregated separately.
Returns:
pd.DataFrame: The dataframe with the aggregated answer.
"""
Expand All @@ -163,6 +164,14 @@ def __call__(
if column not in self._obj.columns:
raise ValueError(f"column {column} not found in DataFrame. Given usr instruction: {user_instruction}")

if group_by:
grouped = self._obj.groupby(group_by)
new_df = pd.DataFrame()
for name, group in grouped:
res = group.sem_agg(user_instruction, all_cols, suffix, None)
new_df = pd.concat([new_df, res])
return new_df

# Sort df by partition_id if it exists
if "_lotus_partition_id" in self._obj.columns:
self._obj = self._obj.sort_values(by="_lotus_partition_id")
Expand Down

0 comments on commit 2f0e09e

Please sign in to comment.