Skip to content

Commit

Permalink
🙈 fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Oct 23, 2023
1 parent 8dee606 commit 2a19f41
Showing 1 changed file with 12 additions and 21 deletions.
33 changes: 12 additions & 21 deletions tests/test_features_feature_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_single_series_group_by_feature_collection(dummy_group_data, group_by):
if "consecutive" in group_by: # group_by_consecutive
result_data_counts = res_df.groupby("store")["number_sold__sum__w=manual"].sum()
else: # group_by_all
result_data_counts = res_df["number_sold__sum"]
result_data_counts = res_df["number_sold__sum__w=manual"]

for index in data_counts.index:
assert data_counts[index] == result_data_counts[index]
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_group_by_feature_collection_with_warnings(dummy_group_data, group_by):
if "consecutive" in group_by: # group_by_consecutive
result_data_counts = res_df.groupby("store")["number_sold__sum__w=manual"].sum()
else: # group_by_all
result_data_counts = res_df["number_sold__sum"]
result_data_counts = res_df["number_sold__sum__w=manual"]

for index in data_counts.index:
assert data_counts[index] == result_data_counts[index]
Expand Down Expand Up @@ -160,15 +160,14 @@ def test_single_series_multiple_features_group_by(dummy_group_data, group_by, n_
data_count_min = dummy_group_data.groupby("store")["number_sold"].min()
data_count_max = dummy_group_data.groupby("store")["number_sold"].max()

postfix = "__w=manual" if "consecutive" in group_by else ""
grouped_res_df_sum = (
res_df.reset_index().groupby("store")["number_sold__sum" + postfix].sum()
res_df.reset_index().groupby("store")["number_sold__sum__w=manual"].sum()
)
grouped_res_df_min = (
res_df.reset_index().groupby("store")["number_sold__amin" + postfix].min()
res_df.reset_index().groupby("store")["number_sold__amin__w=manual"].min()
)
grouped_res_df_max = (
res_df.reset_index().groupby("store")["number_sold__amax" + postfix].max()
res_df.reset_index().groupby("store")["number_sold__amax__w=manual"].max()
)

def assert_results(data, res_data):
Expand Down Expand Up @@ -210,11 +209,10 @@ def sum_2(x: np.ndarray, y: np.ndarray) -> float:

assert_frame_equal(concatted_df, res_df)

postfix = "__w=manual" if "consecutive" in group_by else ""
assert all(
res_df["number_sold__sum" + postfix].values
+ res_df["product__sum" + postfix].values
== res_df["number_sold|product__sum_2" + postfix].values
res_df["number_sold__sum__w=manual"].values
+ res_df["product__sum__w=manual"].values
== res_df["number_sold|product__sum_2__w=manual"].values
)


Expand Down Expand Up @@ -247,10 +245,9 @@ def test_group_by_with_nan_values(dummy_group_data, group_by):

assert_frame_equal(concatted_df, res_df)

postfix = "__w=manual" if "consecutive" in group_by else ""
assert (
dummy_group_data["number_sold"].sum()
> res_df["number_sold__sum" + postfix].sum()
> res_df["number_sold__sum__w=manual"].sum()
)


Expand Down Expand Up @@ -307,15 +304,13 @@ def test_group_by_with_unequal_lengths(group_by):
res_list2 = fc.calculate(
[s_group2, s_val2], return_df=True, n_jobs=1, **{group_by: "user_id"}
)
col = "count__nansum"
col += "__w=manual" if "consecutive" in group_by else ""
correct_res_list = fc.calculate(
[s_group, s_val2], return_df=True, n_jobs=1, **{group_by: "user_id"}
)

for c in res_list.columns:
# compare (compare_col) only with nan-safe col in case of group_by_all
compare_col = c if "consecutive" in group_by else col
compare_col = c if "consecutive" in group_by else "count__nansum__w=manual"
assert np.all(
res_list[c]
== res_list2.loc[res_list.index, compare_col].astype(res_list.dtypes[c])
Expand Down Expand Up @@ -346,10 +341,8 @@ def test_group_by_non_aligned_indices(group_by):
).reset_index()
grouped_non_nan_df_sums = non_nan_df.groupby("groups").sum()

col = "count__nansum"
col += "__w=manual" if "consecutive" in group_by else ""
new_res_list = pd.DataFrame(
{"groups": res_list["user_id"], "values": res_list[col]}
{"groups": res_list["user_id"], "values": res_list["count__nansum__w=manual"]}
)
new_res_list = new_res_list.set_index("groups")

Expand Down Expand Up @@ -395,9 +388,7 @@ def test_group_by_with_numeric_index(group_by):
s_df = pd.DataFrame({"groups": s_group, "values": s_val})

data_counts = s_df.groupby("groups")["values"].sum()
col = "count__nansum"
col += "__w=manual" if "consecutive" in group_by else ""
result_data_counts = res_df.groupby("user_id")[col].sum()
result_data_counts = res_df.groupby("user_id")["count__nansum__w=manual"].sum()

for index in data_counts.index:
assert data_counts[index] == result_data_counts[index]
Expand Down

0 comments on commit 2a19f41

Please sign in to comment.