Skip to content

Commit

Permalink
[SPARK-46500][PS][TESTS] Reorganize FrameParityPivotTests
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Reorganize `FrameParityPivotTests`: break `test_pivot_table` into multiple tests

### Why are the changes needed?
this test is slow
```
Starting test(python3.9): pyspark.pandas.tests.connect.computation.test_parity_pivot (temp output: /__w/spark/spark/python/target/5f37e442-9037-47cc-8c6b-e9a273299d0d/python3.9__pyspark.pandas.tests.connect.computation.test_parity_pivot__ozvdx_ay.log)
Finished test(python3.9): pyspark.pandas.tests.connect.computation.test_parity_pivot (524s)
```

### Does this PR introduce _any_ user-facing change?
no, test only

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#44478 from zhengruifeng/ps_test_pivot_multi.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Dec 25, 2023
1 parent eae0929 commit d85ad1c
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 145 deletions.
8 changes: 8 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,10 @@ def __hash__(self):
"pyspark.pandas.tests.computation.test_melt",
"pyspark.pandas.tests.computation.test_missing_data",
"pyspark.pandas.tests.computation.test_pivot",
"pyspark.pandas.tests.computation.test_pivot_table",
"pyspark.pandas.tests.computation.test_pivot_table_adv",
"pyspark.pandas.tests.computation.test_pivot_table_multi_idx",
"pyspark.pandas.tests.computation.test_pivot_table_multi_idx_adv",
"pyspark.pandas.tests.computation.test_stats",
"pyspark.pandas.tests.frame.test_attrs",
"pyspark.pandas.tests.frame.test_axis",
Expand Down Expand Up @@ -1162,6 +1166,10 @@ def __hash__(self):
python_test_goals=[
# pandas-on-Spark unittests
"pyspark.pandas.tests.connect.computation.test_parity_pivot",
"pyspark.pandas.tests.connect.computation.test_parity_pivot_table",
"pyspark.pandas.tests.connect.computation.test_parity_pivot_table_adv",
"pyspark.pandas.tests.connect.computation.test_parity_pivot_table_multi_idx",
"pyspark.pandas.tests.connect.computation.test_parity_pivot_table_multi_idx_adv",
"pyspark.pandas.tests.connect.computation.test_parity_stats",
"pyspark.pandas.tests.connect.indexes.test_parity_base_slow",
"pyspark.pandas.tests.connect.frame.test_parity_interpolate",
Expand Down
149 changes: 5 additions & 144 deletions python/pyspark/pandas/tests/computation/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,149 +61,6 @@ def test_pivot_table_dtypes(self):
# columns="a", values="b", fill_value=999).dtypes, pdf.pivot_table(index=['e', 'c'],
# columns="a", values="b", fill_value=999).dtypes)

def test_pivot_table(self):
    """Compare ``pivot_table`` on a pandas-on-Spark frame with pandas on the same data.

    Flat column labels are exercised first. The columns of both frames are then
    replaced by a two-level ``MultiIndex`` and the tuple-label variants are checked.
    Every comparison sorts both results by index and goes through
    ``self.assert_eq(..., almost=True)``.
    """
    pdf = pd.DataFrame(
        {
            "a": [4, 2, 3, 4, 8, 6],
            "b": [1, 2, 2, 4, 2, 4],
            "e": [10, 20, 20, 40, 20, 40],
            "c": [1, 2, 9, 4, 7, 4],
            "d": [-1, -2, -3, -4, -5, -6],
        },
        index=np.random.rand(6),
    )
    psdf = ps.from_pandas(pdf)

    def compare(ps_kwargs, pd_kwargs=None):
        # pandas receives the same arguments unless an explicit variant is supplied.
        if pd_kwargs is None:
            pd_kwargs = ps_kwargs
        self.assert_eq(
            psdf.pivot_table(**ps_kwargs).sort_index(),
            pdf.pivot_table(**pd_kwargs).sort_index(),
            almost=True,
        )

    # Checking if both DataFrames have the same results
    flat_cases = [
        dict(columns="a", values="b"),
        dict(index=["c"], columns="a", values="b"),
        dict(index=["c"], columns="a", values="b", aggfunc="sum"),
        dict(index=["c"], columns="a", values=["b"], aggfunc="sum"),
        dict(index=["c"], columns="a", values=["b", "e"], aggfunc="sum"),
        dict(index=["c"], columns="a", values=["b", "e", "d"], aggfunc="sum"),
        dict(
            index=["c"],
            columns="a",
            values=["b", "e"],
            aggfunc={"b": "mean", "e": "sum"},
        ),
        dict(index=["e", "c"], columns="a", values="b"),
        dict(index=["e", "c"], columns="a", values="b", fill_value=999),
    ]
    for kwargs in flat_cases:
        compare(kwargs)

    # multi-index columns
    columns = pd.MultiIndex.from_tuples(
        [("x", "a"), ("x", "b"), ("y", "e"), ("z", "c"), ("w", "d")]
    )
    pdf.columns = columns
    psdf.columns = columns

    # Each pair is (pandas-on-Spark arguments, pandas arguments): the pandas side
    # always passes ``columns`` (and, in the first case, ``values``) wrapped in a list.
    multi_index_cases = [
        (
            dict(columns=("x", "a"), values=("x", "b")),
            dict(columns=[("x", "a")], values=[("x", "b")]),
        ),
        (
            dict(index=[("z", "c")], columns=("x", "a"), values=[("x", "b")]),
            dict(index=[("z", "c")], columns=[("x", "a")], values=[("x", "b")]),
        ),
        (
            dict(
                index=[("z", "c")],
                columns=("x", "a"),
                values=[("x", "b"), ("y", "e")],
            ),
            dict(
                index=[("z", "c")],
                columns=[("x", "a")],
                values=[("x", "b"), ("y", "e")],
            ),
        ),
        (
            dict(
                index=[("z", "c")],
                columns=("x", "a"),
                values=[("x", "b"), ("y", "e"), ("w", "d")],
            ),
            dict(
                index=[("z", "c")],
                columns=[("x", "a")],
                values=[("x", "b"), ("y", "e"), ("w", "d")],
            ),
        ),
        (
            dict(
                index=[("z", "c")],
                columns=("x", "a"),
                values=[("x", "b"), ("y", "e")],
                aggfunc={("x", "b"): "mean", ("y", "e"): "sum"},
            ),
            dict(
                index=[("z", "c")],
                columns=[("x", "a")],
                values=[("x", "b"), ("y", "e")],
                aggfunc={("x", "b"): "mean", ("y", "e"): "sum"},
            ),
        ),
    ]
    for ps_kwargs, pd_kwargs in multi_index_cases:
        compare(ps_kwargs, pd_kwargs)

def test_pivot_table_and_index(self):
# https://github.com/databricks/koalas/issues/805
pdf = pd.DataFrame(
Expand Down Expand Up @@ -332,7 +189,11 @@ def test_pivot_table_errors(self):
psdf.pivot_table(index=["C"], columns="A", values="B", aggfunc={"B": "mean"})


class FramePivotTests(FramePivotMixin, ComparisonTestBase, SQLTestUtils):
class FramePivotTests(
    FramePivotMixin,
    ComparisonTestBase,
    SQLTestUtils,
):
    """Concrete test case combining ``FramePivotMixin`` with the shared test bases."""


Expand Down
93 changes: 93 additions & 0 deletions python/pyspark/pandas/tests/computation/test_pivot_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

import numpy as np
import pandas as pd

from pyspark import pandas as ps
from pyspark.testing.pandasutils import ComparisonTestBase
from pyspark.testing.sqlutils import SQLTestUtils


class PivotTableMixin:
    """``pivot_table`` comparisons between pandas-on-Spark and pandas (flat columns)."""

    def test_pivot_table(self):
        """Run the same ``pivot_table`` calls on both frames and compare the sorted results."""
        source = pd.DataFrame(
            {
                "a": [4, 2, 3, 4, 8, 6],
                "b": [1, 2, 2, 4, 2, 4],
                "e": [10, 20, 20, 40, 20, 40],
                "c": [1, 2, 9, 4, 7, 4],
                "d": [-1, -2, -3, -4, -5, -6],
            },
            index=np.random.rand(6),
        )
        on_spark = ps.from_pandas(source)

        # Keyword sets handed unchanged to both implementations, in order.
        scenarios = [
            dict(columns="a", values="b"),
            dict(index=["c"], columns="a", values="b"),
            dict(index=["c"], columns="a", values="b", aggfunc="sum"),
            dict(index=["c"], columns="a", values=["b"], aggfunc="sum"),
            dict(index=["c"], columns="a", values=["b", "e"], aggfunc="sum"),
        ]
        for kwargs in scenarios:
            actual = on_spark.pivot_table(**kwargs).sort_index()
            expected = source.pivot_table(**kwargs).sort_index()
            self.assert_eq(actual, expected, almost=True)


class PivotTableTests(
    PivotTableMixin,
    ComparisonTestBase,
    SQLTestUtils,
):
    """Concrete test case combining ``PivotTableMixin`` with the shared test bases."""


if __name__ == "__main__":
    from pyspark.pandas.tests.computation.test_pivot_table import *  # noqa: F401

    # Default to unittest's own runner; switch to the XML-reporting runner only
    # when the optional ``xmlrunner`` package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
93 changes: 93 additions & 0 deletions python/pyspark/pandas/tests/computation/test_pivot_table_adv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

import numpy as np
import pandas as pd

from pyspark import pandas as ps
from pyspark.testing.pandasutils import ComparisonTestBase
from pyspark.testing.sqlutils import SQLTestUtils


class PivotTableAdvMixin:
    """Further ``pivot_table`` comparisons: three value columns, dict ``aggfunc``,
    two index columns and ``fill_value``."""

    def test_pivot_table(self):
        """Run the same ``pivot_table`` calls on both frames and compare the sorted results."""
        pandas_frame = pd.DataFrame(
            {
                "a": [4, 2, 3, 4, 8, 6],
                "b": [1, 2, 2, 4, 2, 4],
                "e": [10, 20, 20, 40, 20, 40],
                "c": [1, 2, 9, 4, 7, 4],
                "d": [-1, -2, -3, -4, -5, -6],
            },
            index=np.random.rand(6),
        )
        spark_frame = ps.from_pandas(pandas_frame)

        def check(**kwargs):
            # Identical arguments go to both implementations; results are
            # index-sorted before being compared.
            self.assert_eq(
                spark_frame.pivot_table(**kwargs).sort_index(),
                pandas_frame.pivot_table(**kwargs).sort_index(),
                almost=True,
            )

        check(index=["c"], columns="a", values=["b", "e", "d"], aggfunc="sum")
        check(
            index=["c"],
            columns="a",
            values=["b", "e"],
            aggfunc={"b": "mean", "e": "sum"},
        )
        check(index=["e", "c"], columns="a", values="b")
        check(index=["e", "c"], columns="a", values="b", fill_value=999)


class PivotTableAdvTests(
    PivotTableAdvMixin,
    ComparisonTestBase,
    SQLTestUtils,
):
    """Concrete test case combining ``PivotTableAdvMixin`` with the shared test bases."""


if __name__ == "__main__":
    from pyspark.pandas.tests.computation.test_pivot_table_adv import *  # noqa: F401

    # Default to unittest's own runner; switch to the XML-reporting runner only
    # when the optional ``xmlrunner`` package can be imported.
    testRunner = None
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
Loading

0 comments on commit d85ad1c

Please sign in to comment.