diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 8bf9c68b727..2dca38538e1 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations import math +from functools import partial from xarray.core import dtypes, nputils @@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return coeffs, residuals +def _fill_with_last_one(a, b): + import numpy as np + + # cumreduction apply the push func over all the blocks first so, + # the only missing part is filling the missing values using the + # last data of the previous chunk + return np.where(np.isnan(b), a, b) + + +def _dtype_push(a, axis, dtype=None): + from xarray.core.duck_array_ops import _push + + # Not sure why the blelloch algorithm force to receive a dtype + return _push(a, axis=axis) + + +def _reset_cumsum(a, axis, dtype=None): + import numpy as np + + cumsum = np.cumsum(a, axis=axis) + reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) + return cumsum - reset_points + + +def _last_reset_cumsum(a, axis, keepdims=None): + import numpy as np + + # Take the last cumulative sum taking into account the reset + # This is useful for blelloch method + return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) + + +def _combine_reset_cumsum(a, b, axis): + import numpy as np + + # It is going to sum the previous result until the first + # non nan value + bitmask = np.cumprod(b != 0, axis=axis) + return np.where(bitmask, b + a, b) + + def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push @@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"): # TODO: Replace all this function # once https://github.com/pydata/xarray/issues/9229 being implemented - def _fill_with_last_one(a, b): - # cumreduction apply the push func over all the blocks first so, - # the only missing part is filling the missing values using the - # last data of the previous chunk - return np.where(np.isnan(b), a, b) - - def _dtype_push(a, axis, dtype=None): - # Not sure why the blelloch algorithm force to receive a dtype - return _push(a, axis=axis) - pushed_array = da.reductions.cumreduction( func=_dtype_push, binop=_fill_with_last_one, @@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None): ) if n is not None and 0 < n < array.shape[axis] - 1: - - def _reset_cumsum(a, axis, dtype=None): - cumsum = np.cumsum(a, axis=axis) - reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis) - return cumsum - reset_points - - def _last_reset_cumsum(a, axis, keepdims=None): - # Take the last cumulative sum taking into account the reset - # This is useful for blelloch method - return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1]) - - def _combine_reset_cumsum(a, b): - # It is going to sum the previous result until the first - # non nan value - bitmask = np.cumprod(b != 0, axis=axis) - return np.where(bitmask, b + a, b) - valid_positions = da.reductions.cumreduction( func=_reset_cumsum, - binop=_combine_reset_cumsum, + binop=partial(_combine_reset_cumsum, axis=axis), ident=0, x=da.isnan(array, dtype=int), axis=axis,