Skip to content

Commit

Permalink
Added additional drop nan test case
Browse files Browse the repository at this point in the history
  • Loading branch information
christopherbunn committed Jan 26, 2024
1 parent e1f5457 commit 60a7fad
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions evalml/tests/component_tests/test_drop_nan_rows_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def test_drop_rows_transformer():
assert_frame_equal(fit_transformed_X, X_expected)


@pytest.mark.parametrize("y_is_df", [True, False])
@pytest.mark.parametrize("null_value", [pd.NA, np.NaN])
def test_drop_rows_transformer_retain_ww_schema(null_value):
def test_drop_rows_transformer_retain_ww_schema(null_value, y_is_df):
# Expecting float because of np.NaN values
X = pd.DataFrame(
{"a column": [null_value, 2, 3, 4], "another col": ["a", null_value, "c", "d"]},
Expand All @@ -46,20 +47,47 @@ def test_drop_rows_transformer_retain_ww_schema(null_value):
)
X_expected_schema = X.ww.schema

y = pd.Series([3, 2, 1, null_value])
y = init_series(y, logical_type="IntegerNullable", semantic_tags="y_custom_tag")
if y_is_df:
y = pd.DataFrame(
{"series_a": [3, 2, 1, null_value], "series_b": [1, null_value, 3, 4]},
)
y.ww.init()
y.ww.set_types(
logical_types={
"series_a": "IntegerNullable",
"series_b": "IntegerNullable",
},
semantic_tags={"series_a": "custom_tag_a", "series_b": "custom_tag_b"},
)

y_expected = pd.Series([1], index=[2])
y_expected = init_series(
y_expected,
logical_type="IntegerNullable",
semantic_tags="y_custom_tag",
)
y_expected = pd.DataFrame({"series_a": [1], "series_b": [3]}, index=[2])
y_expected.ww.init()
y_expected.ww.set_types(
logical_types={
"series_a": "IntegerNullable",
"series_b": "IntegerNullable",
},
semantic_tags={"series_a": "custom_tag_a", "series_b": "custom_tag_b"},
)
else:
y = pd.Series([3, 2, 1, null_value])
y = init_series(y, logical_type="IntegerNullable", semantic_tags="y_custom_tag")

y_expected = pd.Series([1], index=[2])
y_expected = init_series(
y_expected,
logical_type="IntegerNullable",
semantic_tags="y_custom_tag",
)
y_expected_schema = y.ww.schema

drop_rows_transformer = DropNaNRowsTransformer()
transformed_X, transformed_y = drop_rows_transformer.fit_transform(X, y)
assert_frame_equal(transformed_X, X_expected)
assert_series_equal(transformed_y, y_expected)
assert _schema_is_equal(transformed_X.ww.schema, X_expected_schema)

if y_is_df:
assert_frame_equal(transformed_y, y_expected)
else:
assert_series_equal(transformed_y, y_expected)
assert transformed_y.ww.schema == y_expected_schema

0 comments on commit 60a7fad

Please sign in to comment.