Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize cumulative reduction (scan) to non-dask types #8019

Merged
merged 12 commits into from
Dec 18, 2023
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
use non-dask chunked array types.
(:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
`coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
`By Ian Carroll <https://github.com/itcarroll>`_.
Expand Down
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
) -> DaskArray:
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
)

def apply_gufunc(
self,
func: Callable,
Expand Down
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.
Used in ``ffill`` and ``bfill`` in xarray.
Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: dask Array
axis: int, optional
dtype: dtype
Returns
-------
Chunked array
See also
--------
dask.array.cumreduction
"""
raise NotImplementedError()

@abstractmethod
def apply_gufunc(
self,
Expand Down
Loading