Skip to content
forked from pydata/xarray

Commit

Permalink
Avoid local functions in push (pydata#9856)
Browse files Browse the repository at this point in the history
* Avoid local functions in push

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent fab900c commit eac5105
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
from functools import partial

from xarray.core import dtypes, nputils

Expand Down Expand Up @@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return coeffs, residuals


def _fill_with_last_one(a, b):
import numpy as np

# cumreduction apply the push func over all the blocks first so,
# the only missing part is filling the missing values using the
# last data of the previous chunk
return np.where(np.isnan(b), a, b)


def _dtype_push(a, axis, dtype=None):
from xarray.core.duck_array_ops import _push

# Not sure why the blelloch algorithm force to receive a dtype
return _push(a, axis=axis)


def _reset_cumsum(a, axis, dtype=None):
import numpy as np

cumsum = np.cumsum(a, axis=axis)
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
return cumsum - reset_points


def _last_reset_cumsum(a, axis, keepdims=None):
import numpy as np

# Take the last cumulative sum taking into account the reset
# This is useful for blelloch method
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])


def _combine_reset_cumsum(a, b, axis):
import numpy as np

# It is going to sum the previous result until the first
# non nan value
bitmask = np.cumprod(b != 0, axis=axis)
return np.where(bitmask, b + a, b)


def push(array, n, axis, method="blelloch"):
"""
Dask-aware bottleneck.push
Expand All @@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"):
# TODO: Replace all this function
# once https://github.com/pydata/xarray/issues/9229 being implemented

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so,
# the only missing part is filling the missing values using the
# last data of the previous chunk
return np.where(np.isnan(b), a, b)

def _dtype_push(a, axis, dtype=None):
# Not sure why the blelloch algorithm force to receive a dtype
return _push(a, axis=axis)

pushed_array = da.reductions.cumreduction(
func=_dtype_push,
binop=_fill_with_last_one,
Expand All @@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None):
)

if n is not None and 0 < n < array.shape[axis] - 1:

def _reset_cumsum(a, axis, dtype=None):
cumsum = np.cumsum(a, axis=axis)
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
return cumsum - reset_points

def _last_reset_cumsum(a, axis, keepdims=None):
# Take the last cumulative sum taking into account the reset
# This is useful for blelloch method
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])

def _combine_reset_cumsum(a, b):
# It is going to sum the previous result until the first
# non nan value
bitmask = np.cumprod(b != 0, axis=axis)
return np.where(bitmask, b + a, b)

valid_positions = da.reductions.cumreduction(
func=_reset_cumsum,
binop=_combine_reset_cumsum,
binop=partial(_combine_reset_cumsum, axis=axis),
ident=0,
x=da.isnan(array, dtype=int),
axis=axis,
Expand Down

0 comments on commit eac5105

Please sign in to comment.