Skip to content

Commit

Permalink
FEAT-#7254: Support right merge/join (#7226)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
Co-authored-by: Iaroslav Igoshev <[email protected]>
  • Loading branch information
anmyachev and YarShev authored May 13, 2024
1 parent 9992b12 commit 1c0d9a6
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 51 deletions.
5 changes: 4 additions & 1 deletion modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3321,8 +3321,11 @@ def _extract_partitions(self):
if self._partitions.size > 0:
return self._partitions
else:
dtypes = None
if self.has_materialized_dtypes:
dtypes = self.dtypes
return self._partition_mgr_cls.create_partition_from_metadata(
index=self.index, columns=self.columns
index=self.index, columns=self.columns, dtypes=dtypes
)

@lazy_metadata_decorator(apply_axis="both")
Expand Down
12 changes: 10 additions & 2 deletions modin/core/dataframe/pandas/partitioning/partition_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import warnings
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -183,12 +183,18 @@ def preprocess_func(cls, map_func):
# END Abstract Methods

@classmethod
def create_partition_from_metadata(cls, **metadata):
def create_partition_from_metadata(
cls, dtypes: Optional[pandas.Series] = None, **metadata
):
"""
Create NumPy array of partitions that holds an empty dataframe with given metadata.
Parameters
----------
dtypes : pandas.Series, optional
Column dtypes.
Upon creating a pandas DataFrame from `metadata` we call `astype`, since
the pandas constructor doesn't accept a separate dtype for each column.
**metadata : dict
Metadata that has to be wrapped in a partition.
Expand All @@ -198,6 +204,8 @@ def create_partition_from_metadata(cls, **metadata):
A NumPy 2D array of a single partition which contains the data.
"""
metadata_dataframe = pandas.DataFrame(**metadata)
if dtypes is not None:
metadata_dataframe = metadata_dataframe.astype(dtypes)
return np.array([[cls._partition_class.put(metadata_dataframe)]])

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from . import doc_utils

if TYPE_CHECKING:
from typing_extensions import Self

# TODO: should be ModinDataframe
# https://github.com/modin-project/modin/issues/7244
from modin.core.dataframe.pandas.dataframe.dataframe import PandasDataframe
Expand Down Expand Up @@ -158,7 +160,7 @@ def __wrap_in_qc(self, obj):
else:
return obj

def default_to_pandas(self, pandas_op, *args, **kwargs):
def default_to_pandas(self, pandas_op, *args, **kwargs) -> Self:
"""
Do fallback to pandas for the passed function.
Expand Down Expand Up @@ -4467,7 +4469,7 @@ def write_items(df, broadcasted_items):
# END Abstract methods for QueryCompiler

@cached_property
def __constructor__(self) -> type[BaseQueryCompiler]:
def __constructor__(self) -> type[Self]:
"""
Get query compiler constructor.
Expand Down
49 changes: 41 additions & 8 deletions modin/core/storage_formats/pandas/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import pandas
from pandas.core.dtypes.common import is_list_like
Expand Down Expand Up @@ -103,7 +103,7 @@ def func(left, right):
@classmethod
def row_axis_merge(
cls, left: PandasQueryCompiler, right: PandasQueryCompiler, kwargs: dict
):
) -> PandasQueryCompiler:
"""
Execute merge using row-axis implementation.
Expand All @@ -126,10 +126,25 @@ def row_axis_merge(
right_index = kwargs.get("right_index", False)
sort = kwargs.get("sort", False)

if how in ["left", "inner"] and left_index is False and right_index is False:
if (
(
how in ["left", "inner"]
or (how == "right" and right._modin_frame._partitions.size != 0)
)
and left_index is False
and right_index is False
):
kwargs["sort"] = False

def should_keep_index(left, right):
reverted = False
if how == "right":
left, right = right, left
reverted = True

def should_keep_index(
left: PandasQueryCompiler,
right: PandasQueryCompiler,
) -> bool:
keep_index = False
if left_on is not None and right_on is not None:
keep_index = any(
Expand All @@ -144,8 +159,14 @@ def should_keep_index(left, right):
)
return keep_index

def map_func(left, right): # pragma: no cover
return pandas.merge(left, right, **kwargs)
def map_func(
left, right, kwargs=kwargs
) -> pandas.DataFrame: # pragma: no cover
if reverted:
df = pandas.merge(right, left, **kwargs)
else:
df = pandas.merge(left, right, **kwargs)
return df

# Want to ensure that these are python lists
if left_on is not None and right_on is not None:
Expand All @@ -156,7 +177,11 @@ def map_func(left, right): # pragma: no cover

right_to_broadcast = right._modin_frame.combine()
new_columns, new_dtypes = cls._compute_result_metadata(
left, right, on, left_on, right_on, kwargs.get("suffixes", ("_x", "_y"))
*((left, right) if not reverted else (right, left)),
on,
left_on,
right_on,
kwargs.get("suffixes", ("_x", "_y")),
)

# We rebalance when the ratio of the number of existing partitions to
Expand Down Expand Up @@ -226,7 +251,15 @@ def map_func(left, right): # pragma: no cover
return left.default_to_pandas(pandas.DataFrame.merge, right, **kwargs)

@classmethod
def _compute_result_metadata(cls, left, right, on, left_on, right_on, suffixes):
def _compute_result_metadata(
cls,
left: PandasQueryCompiler,
right: PandasQueryCompiler,
on,
left_on,
right_on,
suffixes,
) -> tuple[Optional[pandas.Index], Optional[ModinDtypes]]:
"""
Compute columns and dtypes metadata for the result of merge if possible.
Expand Down
35 changes: 24 additions & 11 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,33 +526,46 @@ def merge(self, right, **kwargs):
get_logger().info(message)
return MergeImpl.row_axis_merge(self, right, kwargs)

def join(self, right, **kwargs):
def join(self, right: PandasQueryCompiler, **kwargs) -> PandasQueryCompiler:
on = kwargs.get("on", None)
how = kwargs.get("how", "left")
sort = kwargs.get("sort", False)
left = self

if how in ["left", "inner"]:

def map_func(left, right, kwargs=kwargs): # pragma: no cover
return pandas.DataFrame.join(left, right, **kwargs)
if how in ["left", "inner"] or (
how == "right" and right._modin_frame._partitions.size != 0
):
reverted = False
if how == "right":
left, right = right, left
reverted = True

def map_func(
left, right, kwargs=kwargs
) -> pandas.DataFrame: # pragma: no cover
if reverted:
df = pandas.DataFrame.join(right, left, **kwargs)
else:
df = pandas.DataFrame.join(left, right, **kwargs)
return df

right_to_broadcast = right._modin_frame.combine()
new_self = self.__constructor__(
self._modin_frame.broadcast_apply_full_axis(
left = left.__constructor__(
left._modin_frame.broadcast_apply_full_axis(
axis=1,
func=map_func,
# We're going to explicitly change the shape across the 1-axis,
# so we want for partitioning to adapt as well
keep_partitioning=False,
num_splits=merge_partitioning(
self._modin_frame, right._modin_frame, axis=1
left._modin_frame, right._modin_frame, axis=1
),
other=right_to_broadcast,
)
)
return new_self.sort_rows_by_column_values(on) if sort else new_self
return left.sort_rows_by_column_values(on) if sort else left
else:
return self.default_to_pandas(pandas.DataFrame.join, right, **kwargs)
return left.default_to_pandas(pandas.DataFrame.join, right, **kwargs)

# END Inter-Data operations

Expand Down Expand Up @@ -588,7 +601,7 @@ def reindex(self, axis, labels, **kwargs):
)
return self.__constructor__(new_modin_frame)

def reset_index(self, **kwargs):
def reset_index(self, **kwargs) -> PandasQueryCompiler:
if self.lazy_execution:

def _reset(df, *axis_lengths, partition_idx): # pragma: no cover
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1454,9 +1454,14 @@ def _join_by_index(self, other_modin_frames, how, sort, ignore_index):
condition=condition,
)

new_columns = Index.__new__(
Index, data=new_columns, dtype=new_columns_dtype
)
# In the case of heterogeneous data, passing the `dtype` parameter to the
# `Index` constructor can raise an error such as:
# `ValueError: string values cannot be losslessly cast to int64`
# That's why we build the index first and explicitly call `astype` below.
new_columns = Index(new_columns)
if new_columns.dtype != new_columns_dtype and new_columns_dtype is not None:
# the inferred dtype differs from the requested one, so cast explicitly
new_columns = new_columns.astype(new_columns_dtype)
lhs = lhs.__constructor__(
dtypes=lhs._dtypes_for_exprs(exprs),
columns=new_columns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1347,13 +1347,13 @@ def build_row_idx_filter_expr(row_idx, row_col):
return row_col.eq(row_idx)

if is_range_like(row_idx):
start = row_idx[0]
stop = row_idx[-1]
start = row_idx.start
stop = row_idx.stop
step = row_idx.step
if step < 0:
start, stop = stop, start
step = -step
exprs = [row_col.ge(start), row_col.le(stop)]
exprs = [row_col.ge(start), row_col.cmp("<", stop)]
if step > 1:
mod = OpExpr("MOD", [row_col, LiteralExpr(step)], _get_dtype(int))
exprs.append(mod.eq(0))
Expand Down
Loading

0 comments on commit 1c0d9a6

Please sign in to comment.