diff --git a/modin/core/dataframe/pandas/partitioning/partition_manager.py b/modin/core/dataframe/pandas/partitioning/partition_manager.py index cb207f64d4e..6cf22b5f7d7 100644 --- a/modin/core/dataframe/pandas/partitioning/partition_manager.py +++ b/modin/core/dataframe/pandas/partitioning/partition_manager.py @@ -915,18 +915,22 @@ def map_partitions_joined_by_column( # step cannot be less than 1 step = max(partitions.shape[0] // column_splits, 1) preprocessed_map_func = cls.preprocess_func(map_func) - kw = { - "num_splits": step, - } result = np.empty(partitions.shape, dtype=object) for i in range( 0, partitions.shape[0], step, ): - joined_column_partitions = cls.column_partitions(partitions[i : i + step]) + partitions_subset = partitions[i : i + step] + # This is necessary when ``partitions.shape[0]`` is not divisible + # by `column_splits` without a remainder. + actual_step = len(partitions_subset) + kw = { + "num_splits": actual_step, + } + joined_column_partitions = cls.column_partitions(partitions_subset) for j in range(partitions.shape[1]): - result[i : i + step, j] = joined_column_partitions[j].apply( + result[i : i + actual_step, j] = joined_column_partitions[j].apply( preprocessed_map_func, *map_func_args if map_func_args is not None else (), **kw, diff --git a/modin/tests/core/storage_formats/pandas/test_internals.py b/modin/tests/core/storage_formats/pandas/test_internals.py index b030fe7b216..5d3e99c0236 100644 --- a/modin/tests/core/storage_formats/pandas/test_internals.py +++ b/modin/tests/core/storage_formats/pandas/test_internals.py @@ -2677,8 +2677,9 @@ def test_dynamic_partitioning(partitioning_scheme, expected_map_approach): expected_method.assert_called() -def test_map_partitions_joined_by_column(): - with context(NPartitions=CpuCount.get() * 2): +@pytest.mark.parametrize("npartitions", [7, CpuCount.get() * 2]) +def test_map_partitions_joined_by_column(npartitions): + with context(NPartitions=npartitions): ncols = MinColumnPartitionSize.get() nrows = MinRowPartitionSize.get() * CpuCount.get() * 2 data = {f"col{i}": np.ones(nrows) for i in range(ncols)}