Skip to content

Commit

Permalink
no "index" column in aggtarget output (#1020)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromedockes authored Aug 1, 2024
1 parent 48ad2fd commit b78a5f2
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ Minor changes
duplicate column names, now the output names are always the same.
:pr:`1013` by :user:`Jérôme Dockès <jeromedockes>`.

* In some cases :class:`AggJoiner` and :class:`AggTarget` inserted a column
named "index" in the output, containing the pandas index of the auxiliary table.
This has been corrected.
:pr:`1020` by :user:`Jérôme Dockès <jeromedockes>`.

Release 0.2.0
=============

Expand Down
1 change: 0 additions & 1 deletion skrub/_agg_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ def fit_transform(self, X, y):
y_[self.main_key_] = X[self.main_key_]

num_operations, categ_operations = split_num_categ_operations(self.operation_)

self.y_ = skrub_px.aggregate(
y_,
key=self.main_key_,
Expand Down
4 changes: 2 additions & 2 deletions skrub/_dataframe/_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def aggregate(

named_agg = {**num_named_agg, **categ_named_agg}
if named_agg:
base_group = table.groupby(key).agg(**named_agg)
base_group = table.groupby(key).agg(**named_agg).reset_index(drop=False)
else:
base_group = None

Expand Down Expand Up @@ -104,7 +104,7 @@ def aggregate(
]
sorted_cols = sorted(base_group.columns)

return base_group[sorted_cols].reset_index(drop=False)
return base_group[sorted_cols]


def get_named_agg(table, cols, operations):
Expand Down
5 changes: 3 additions & 2 deletions skrub/_dataframe/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_simple_agg():
"rating_mean": ("rating", "mean"),
}
expected = main.groupby("movieId").agg(**aggfunc).reset_index()
expected = expected.loc[:, sorted(expected.columns)]
assert_frame_equal(aggregated, expected)


Expand All @@ -49,7 +50,7 @@ def test_value_counts_agg():
"rating_4.0_user": [3.0, 1.0],
"userId": [1, 2],
}
).reset_index(drop=False)
)
assert_frame_equal(aggregated, expected)

aggregated = aggregate(
Expand All @@ -66,7 +67,7 @@ def test_value_counts_agg():
"rating_(3.0, 4.0]_user": [3, 1],
"userId": [1, 2],
}
).reset_index(drop=False)
)
assert_frame_equal(aggregated, expected)


Expand Down
1 change: 0 additions & 1 deletion skrub/tests/test_agg_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ def test_agg_target(main_table, y, col_name):
"movieId": [1, 3, 6, 318, 6, 1704],
"rating": [4.0, 4.0, 4.0, 3.0, 2.0, 4.0],
"genre": ["drama", "drama", "comedy", "sf", "comedy", "sf"],
"index": [0, 0, 0, 1, 1, 1],
f"{col_name}_(1.999, 3.0]_user": [0, 0, 0, 2, 2, 2],
f"{col_name}_(3.0, 4.0]_user": [3, 3, 3, 1, 1, 1],
f"{col_name}_2.0_user": [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
Expand Down

0 comments on commit b78a5f2

Please sign in to comment.