forked from pydata/xarray

Commit

Use shuffle in groupby binary ops.
dcherian committed Dec 9, 2024
1 parent eac5105 commit df599a6
Showing 2 changed files with 70 additions and 17 deletions.
xarray/core/groupby.py: 83 changes (67 additions & 16 deletions)
@@ -20,6 +20,7 @@
from xarray.core.alignment import align, broadcast
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.computation import apply_ufunc
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.duck_array_ops import where
@@ -49,7 +50,7 @@
peek_at,
)
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array
from xarray.util.deprecation_helpers import _deprecate_positional_args

if TYPE_CHECKING:
@@ -900,25 +901,75 @@ def _binary_op(self, other, f, reflexive=False):
group = group.where(~mask, drop=True)
codes = codes.where(~mask, drop=True).astype(int)

# if other is dask-backed, that's a hint that the
# "expanded" dataset is too big to hold in memory.
# this can be the case when `other` was read from disk
# and contains our lazy indexing classes
# We need to check for dask-backed Datasets,
# so utils.is_duck_dask_array (an array-level check) does not work here
if obj.chunks and not other.chunks:
# TODO: What about datasets with some dask vars, and others not?
# This handles dims other than `name`
chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
# a chunk size of 1 seems reasonable since we expect individual elements of
# other to be repeated multiple times across the reduced dimension(s)
chunks[name] = 1
other = other.chunk(chunks)
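
For context, a minimal in-memory sketch of the approach removed here, with hypothetical names and values: chunk `other` into size-1 chunks along the group dimension so that the subsequent `isel` expansion stays lazy.

```python
import numpy as np
import xarray as xr

# One value per group, as `other` would be after a groupby reduction.
other = xr.DataArray(np.array([0.0, 1.0]), dims="label")
# Integer group codes along the original dimension.
codes = xr.DataArray(np.array([0, 1, 1, 0, 0, 1]), dims="x")

# Size-1 chunks along the group dim keep the expansion lazy:
# each repeated element is read from its own small chunk.
expanded = other.chunk({"label": 1}).isel(label=codes)
print(type(expanded.data))  # a dask array; nothing has been computed yet
```
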
def _vindex_wrapper(array, idxr, like):
# We want to exploit the fact that we know the chunksizes of the output
# (they match obj's), so we can't just use Variable's indexing
import dask
from dask.array.core import slices_from_chunks
from dask.graph_manipulation import clone

array = clone(array) # FIXME: add to dask

assert array.ndim == 1
to_shape = like.shape[-1:]
to_chunks = like.chunks[-1:]
flat_indices = [
idxr[slicer].ravel().tolist()
for slicer in slices_from_chunks(to_chunks)
]
# FIXME: figure out axis
shuffled = dask.array.shuffle(
array, flat_indices, axis=array.ndim - 1, chunks="auto"
)
if shuffled.shape != to_shape:
return dask.array.reshape_blockwise(
shuffled, shape=to_shape, chunks=to_chunks
)
else:
return shuffled
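
The wrapper above leans on `dask.array.shuffle` (available in recent dask releases), which reorders one axis so that every inner list of the indexer is guaranteed to land inside a single output chunk. A minimal sketch of that behavior with illustrative values (a plain permutation here; in this commit the indexer is the flattened group codes per output chunk):

```python
import numpy as np
import dask.array as da

arr = da.from_array(np.arange(6), chunks=2)  # chunks: (2, 2, 2)

# Each inner list is kept together in one output chunk (small lists
# may be merged to avoid fragmentation); output order follows the indexer.
shuffled = da.shuffle(arr, [[4, 1, 3], [5, 0, 2]], axis=0)
print(shuffled.compute())  # [4 1 3 5 0 2]
```
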

# codes are defined for coord, so we align `other` with `coord`
# before indexing
other, _ = align(other, coord, join="right", copy=False)
expanded = other.isel({name: codes})

other_as_dataset = (
other._to_temp_dataset() if isinstance(other, DataArray) else other
)
obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
dask_vars = []
non_dask_vars = []
for varname, var in other_as_dataset._variables.items():
if is_duck_dask_array(var._data):
dask_vars.append(varname)
else:
non_dask_vars.append(varname)
expanded = other_as_dataset[non_dask_vars].isel({name: codes})
if dask_vars:
other_dims = other_as_dataset[dask_vars].dims
obj_dims = obj_as_dataset[dask_vars].dims
expanded = expanded.merge(
apply_ufunc(
_vindex_wrapper,
other_as_dataset[dask_vars],
codes,
obj_as_dataset[dask_vars],
input_core_dims=[
tuple(other_dims), # FIXME: ..., name
tuple(codes.dims),
tuple(obj_dims),
],
# When other is the result of a reduction over Ellipsis
# obj.dims is a superset of other.dims, and contains
# dims not present in the output
exclude_dims=set(obj_dims) - set(other_dims),
output_core_dims=[tuple(codes.dims)],
dask="allowed",
join=OPTIONS["arithmetic_join"],
)
)

if isinstance(other, DataArray):
expanded = other._from_temp_dataset(expanded)

result = g(obj, expanded)
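
A toy, eager illustration of the `apply_ufunc` wiring above, using plain NumPy indexing in place of `_vindex_wrapper` (all names and values here are hypothetical):

```python
import numpy as np
import xarray as xr

def _pick(values, codes, like):
    # Stand-in for _vindex_wrapper: expand `values` along the last
    # axis using the integer `codes`; `like` only supplies shape info.
    assert like.shape[-1] == codes.shape[-1]
    return values[..., codes]

other = xr.DataArray(np.array([10.0, 20.0]), dims="label")  # one value per group
codes = xr.DataArray(np.array([0, 1, 1, 0]), dims="x")      # group code per element
like = xr.DataArray(np.zeros((3, 4)), dims=("y", "x"))      # stands in for obj

expanded = xr.apply_ufunc(
    _pick, other, codes, like,
    input_core_dims=[("label",), ("x",), ("y", "x")],
    # Mirroring the commit: obj dims absent from other must be excluded,
    # since they are either consumed (y) or allowed to change size (x).
    exclude_dims=set(like.dims) - set(other.dims),  # {"y", "x"}
    output_core_dims=[("x",)],
)
print(expanded.values)  # [10. 20. 20. 10.]
```
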

xarray/tests/test_groupby.py: 4 changes (3 additions & 1 deletion)
@@ -2654,12 +2654,14 @@ def test_groupby_math_auto_chunk() -> None:
dims=("y", "x"),
coords={"label": ("x", [2, 2, 1])},
)
# `sub` plays the role of da.groupby("label").min(...)
sub = xr.DataArray(
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
)
chunked = da.chunk(x=1, y=2)
chunked.label.load()
actual = chunked.groupby("label") - sub
with raise_if_dask_computes():
actual = chunked.groupby("label") - sub
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
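
For reference, `raise_if_dask_computes` is a helper from `xarray.tests`; the idea, roughly (a sketch, not xarray's exact implementation), is to install a counting scheduler for the duration of the block:

```python
import contextlib
import dask

@contextlib.contextmanager
def _raise_if_dask_computes(max_computes=0):
    # Any compute triggered inside the block raises once the budget is spent.
    n = 0
    def counting_get(dsk, keys, **kwargs):
        nonlocal n
        n += 1
        if n > max_computes:
            raise RuntimeError(f"dask computed {n} times (allowed: {max_computes})")
        return dask.get(dsk, keys, **kwargs)
    with dask.config.set(scheduler=counting_get):
        yield
```
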


