diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 9596d19e735..1bf78a13bf8 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -20,6 +20,7 @@
 from xarray.core.alignment import align, broadcast
 from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
+from xarray.core.computation import apply_ufunc
 from xarray.core.concat import concat
 from xarray.core.coordinates import Coordinates, _coordinates_from_variable
 from xarray.core.duck_array_ops import where
@@ -49,7 +50,7 @@
     peek_at,
 )
 from xarray.core.variable import IndexVariable, Variable
-from xarray.namedarray.pycompat import is_chunked_array
+from xarray.namedarray.pycompat import is_chunked_array, is_duck_dask_array
 from xarray.util.deprecation_helpers import _deprecate_positional_args
 
 if TYPE_CHECKING:
@@ -900,25 +901,75 @@ def _binary_op(self, other, f, reflexive=False):
             group = group.where(~mask, drop=True)
             codes = codes.where(~mask, drop=True).astype(int)
 
-        # if other is dask-backed, that's a hint that the
-        # "expanded" dataset is too big to hold in memory.
-        # this can be the case when `other` was read from disk
-        # and contains our lazy indexing classes
-        # We need to check for dask-backed Datasets
-        # so utils.is_duck_dask_array does not work for this check
-        if obj.chunks and not other.chunks:
-            # TODO: What about datasets with some dask vars, and others not?
-            # This handles dims other than `name``
-            chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
-            # a chunk size of 1 seems reasonable since we expect individual elements of
-            # other to be repeated multiple times across the reduced dimension(s)
-            chunks[name] = 1
-            other = other.chunk(chunks)
+        def _vindex_wrapper(array, idxr, like):
+            # We want to exploit the fact that we know the output chunksizes
+            # (they match `obj`), so we can't just use Variable's indexing.
+            import dask
+            from dask.array.core import slices_from_chunks
+            from dask.graph_manipulation import clone
+
+            array = clone(array)  # FIXME: add to dask
+
+            assert array.ndim == 1
+            to_shape = like.shape[-1:]
+            to_chunks = like.chunks[-1:]
+            flat_indices = [
+                idxr[slicer].ravel().tolist()
+                for slicer in slices_from_chunks(to_chunks)
+            ]
+            # FIXME: figure out axis
+            shuffled = dask.array.shuffle(
+                array, flat_indices, axis=array.ndim - 1, chunks="auto"
+            )
+            if shuffled.shape != to_shape:
+                return dask.array.reshape_blockwise(
+                    shuffled, shape=to_shape, chunks=to_chunks
+                )
+            else:
+                return shuffled
 
         # codes are defined for coord, so we align `other` with `coord`
         # before indexing
         other, _ = align(other, coord, join="right", copy=False)
-        expanded = other.isel({name: codes})
+
+        other_as_dataset = (
+            other._to_temp_dataset() if isinstance(other, DataArray) else other
+        )
+        obj_as_dataset = obj._to_temp_dataset() if isinstance(obj, DataArray) else obj
+        dask_vars = []
+        non_dask_vars = []
+        for varname, var in other_as_dataset._variables.items():
+            if is_duck_dask_array(var._data):
+                dask_vars.append(varname)
+            else:
+                non_dask_vars.append(varname)
+        expanded = other_as_dataset[non_dask_vars].isel({name: codes})
+        if dask_vars:
+            other_dims = other_as_dataset[dask_vars].dims
+            obj_dims = obj_as_dataset[dask_vars].dims
+            expanded = expanded.merge(
+                apply_ufunc(
+                    _vindex_wrapper,
+                    other_as_dataset[dask_vars],
+                    codes,
+                    obj_as_dataset[dask_vars],
+                    input_core_dims=[
+                        tuple(other_dims),  # FIXME: ..., name
+                        tuple(codes.dims),
+                        tuple(obj_dims),
+                    ],
+                    # When `other` is the result of a reduction over Ellipsis,
+                    # obj.dims is a superset of other.dims and contains
+                    # dims not present in the output.
+                    exclude_dims=set(obj_dims) - set(other_dims),
+                    output_core_dims=[tuple(codes.dims)],
+                    dask="allowed",
+                    join=OPTIONS["arithmetic_join"],
+                )
+            )
+
+        if isinstance(other, DataArray):
+            expanded = other._from_temp_dataset(expanded)
 
         result = g(obj, expanded)
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index e4383dd58a9..272f012564b 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -2654,12 +2654,14 @@ def test_groupby_math_auto_chunk() -> None:
         dims=("y", "x"),
         coords={"label": ("x", [2, 2, 1])},
     )
+    # sub is equivalent to da.groupby("label").min(...)
    sub = xr.DataArray(
         InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
     )
     chunked = da.chunk(x=1, y=2)
     chunked.label.load()
-    actual = chunked.groupby("label") - sub
+    with raise_if_dask_computes():
+        actual = chunked.groupby("label") - sub
     assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}
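For context, `_vindex_wrapper` above replaces vectorized (`vindex`) indexing with `dask.array.shuffle`, carving the indexer into one sublist per desired output chunk so the output's chunk boundaries are known before the graph runs. Below is a minimal, illustrative sketch of that trick, assuming a dask version that ships `dask.array.shuffle` (2024.08 or later); the array values, chunk layout, and indexer are made up for the example:

```python
# Illustrative sketch only: a shuffle with a per-chunk indexer reproduces
# fancy indexing while pinning the output chunk boundaries up front.
import dask.array as da
import numpy as np

arr = da.from_array(np.arange(10) * 10, chunks=5)

# One sublist per desired output chunk. The values are equivalent to
# arr[[7, 2, 9, 3, 6, 0, 4, 5, 1, 8]], but each sublist is kept together
# inside a single chunk of the output.
indexer = [[7, 2, 9, 3], [6, 0, 4], [5, 1, 8]]
shuffled = da.shuffle(arr, indexer, axis=0, chunks="auto")
print(shuffled.compute())  # [70 20 90 30 60  0 40 50 10 80]
```

The diff additionally relies on repeated indices (group codes repeat along the grouped dimension) and falls back to `reshape_blockwise` when the shuffled result does not already have the target shape; both spots are flagged with FIXMEs above.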
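And a usage-level sketch of the behavior the updated test pins down, mirroring `test_groupby_math_auto_chunk` (here `sub` stands in for `da.groupby("label").min(...)`): on this branch, groupby arithmetic against the result of a reduction should preserve the chunked operand's chunk sizes rather than forcing size-1 chunks along the grouped dimension, and building the result should not trigger any dask computation.

```python
# Illustrative usage: subtracting a per-group reduction from a chunked
# array keeps the original chunksizes of the chunked operand.
import xarray as xr

da = xr.DataArray(
    [[1, 2, 3], [1, 2, 3], [1, 2, 3]],
    dims=("y", "x"),
    coords={"label": ("x", [2, 2, 1])},
)
chunked = da.chunk(x=1, y=2)
sub = xr.DataArray([1, 2], dims="label", coords={"label": [1, 2]})

actual = chunked.groupby("label") - sub  # builds a lazy graph only
print(actual.chunksizes)  # {'y': (2, 1), 'x': (1, 1, 1)}
```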