Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize cumulative reduction (scan) to non-dask types #8019

Merged
merged 12 commits into from
Dec 18, 2023
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.ffill` to
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
use non-dask chunked array types.
(:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.

.. _whats-new.2023.07.0:

Expand Down
39 changes: 0 additions & 39 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,42 +53,3 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
# See issue dask/dask#6516
coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
return coeffs, residuals


def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
import bottleneck
import dask.array as da
import numpy as np

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
return np.where(~np.isnan(b), b, a)

if n is not None and 0 < n < array.shape[axis] - 1:
arange = da.broadcast_to(
da.arange(
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
).reshape(
tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
),
array.shape,
array.chunks,
)
valid_arange = da.where(da.notnull(array), arange, np.nan)
valid_limits = (arange - push(valid_arange, None, axis)) <= n
# omit the forward fill that violate the limit
return da.where(valid_limits, push(array, None, axis), np.nan)

# The method parameter makes that the tests for python 3.7 fails.
return da.reductions.cumreduction(
func=bottleneck.push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
axis=axis,
dtype=array.dtype,
)
20 changes: 20 additions & 0 deletions xarray/core/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
):
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
)

def apply_gufunc(
self,
func: Callable,
Expand Down
47 changes: 45 additions & 2 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,54 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)


def _chunked_push(array, n, axis):
"""
Chunk-aware bottleneck.push
"""
import bottleneck
import numpy as np

chunkmanager = get_chunked_array_type(array)
xp = chunkmanager.array_api

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
return np.where(~np.isnan(b), b, a)

if n is not None and 0 < n < array.shape[axis] - 1:
arange = xp.arange(
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
)
broadcasted_arange = xp.broadcast_to(
xp.reshape(
arange,
tuple(size if i == axis else 1 for i, size in enumerate(array.shape)),
),
array.shape,
array.chunks,
)
valid_arange = xp.where(xp.notnull(array), broadcasted_arange, np.nan)
valid_limits = (arange - push(valid_arange, None, axis)) <= n
# omit the forward fill that violate the limit
return xp.where(valid_limits, push(array, None, axis), np.nan)

# The method parameter makes that the tests for python 3.7 fails.
dcherian marked this conversation as resolved.
Show resolved Hide resolved
return chunkmanager.scan(
func=bottleneck.push,
binop=_fill_with_last_one,
ident=np.nan,
arr=array,
axis=axis,
dtype=array.dtype,
)


def push(array, n, axis):
from bottleneck import push

if is_duck_dask_array(array):
return dask_array_ops.push(array, n, axis)
if is_chunked_array(array):
return _chunked_push(array, n, axis)
else:
return push(array, n, axis)

Expand Down
36 changes: 36 additions & 0 deletions xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,42 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
):
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
"""
General version of a 1D scan, also known as a cumulative array reduction.

Used in ``ffill`` and ``bfill` in xarray.

Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: dask Array
axis: int, optional
dtype: dtype

Returns
-------
Chunked array

See also
--------
dask.array.cumreduction
"""
raise NotImplementedError()

@abstractmethod
def apply_gufunc(
self,
Expand Down
Loading