
Optimize ffill, bfill with dask when limit is specified #9771

Merged

Changes from all commits (15 commits)

4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -29,6 +29,10 @@ New Features
 - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
   (:issue:`2852`, :issue:`757`).
   By `Deepak Cherian <https://github.com/dcherian>`_.
+- Optimize ffill, bfill with dask when limit is specified
+  (:pull:`9771`).
+  By `Joseph Nowak <https://github.com/josephnowak>`_, and
+  `Patrick Hoefler <https://github.com/phofl>`_.
 - Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`).
   By `Sam Levang <https://github.com/slevang>`_ and `Tien Vo <https://github.com/tien-vo>`_.
 - Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with
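
As a usage illustration (not part of the diff): the entry above covers the dask code path of ffill/bfill when a limit is given. A minimal sketch, assuming dask plus bottleneck (or numbagg) are installed:

    import numpy as np
    import xarray as xr

    # Chunked DataArray with gaps; ffill with a limit stays fully lazy, and
    # the limit handling now runs as a cumulative reduction over the chunks
    # instead of building a broadcast index array.
    da = xr.DataArray(
        [np.nan, 1.0, np.nan, np.nan, np.nan, 5.0, np.nan], dims="x"
    ).chunk({"x": 2})

    filled = da.ffill(dim="x", limit=2)   # no compute is triggered here
    print(filled.compute().values)        # [nan  1.  1.  1. nan  5.  5.]
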
74 changes: 52 additions & 22 deletions xarray/core/dask_array_ops.py
@@ -75,41 +75,71 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     return coeffs, residuals


-def push(array, n, axis):
+def push(array, n, axis, method="blelloch"):
     """
     Dask-aware bottleneck.push
     """
     import dask.array as da
     import numpy as np

     from xarray.core.duck_array_ops import _push
+    from xarray.core.nputils import nanlast

     if n is not None and all(n <= size for size in array.chunks[axis]):
         return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)

     # TODO: Replace all this function
     # once https://github.com/pydata/xarray/issues/9229 being implemented

     def _fill_with_last_one(a, b):
-        # cumreduction apply the push func over all the blocks first so, the only missing part is filling
-        # the missing values using the last data of the previous chunk
-        return np.where(~np.isnan(b), b, a)
+        # cumreduction applies the push func over all the blocks first,
+        # so the only missing part is filling the missing values using
+        # the last data of the previous chunk
+        return np.where(np.isnan(b), a, b)

-    if n is not None and 0 < n < array.shape[axis] - 1:
-        arange = da.broadcast_to(
-            da.arange(
-                array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
-            ).reshape(
-                tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
-            ),
-            array.shape,
-            array.chunks,
-        )
-        valid_arange = da.where(da.notnull(array), arange, np.nan)
-        valid_limits = (arange - push(valid_arange, None, axis)) <= n
-        # omit the forward fill that violate the limit
-        return da.where(valid_limits, push(array, None, axis), np.nan)
-
-    # The method parameter makes that the tests for python 3.7 fails.
-    return da.reductions.cumreduction(
-        func=_push,
+    def _dtype_push(a, axis, dtype=None):
+        # Not sure why the blelloch algorithm forces receiving a dtype
+        return _push(a, axis=axis)
+
+    pushed_array = da.reductions.cumreduction(
+        func=_dtype_push,
         binop=_fill_with_last_one,
         ident=np.nan,
         x=array,
         axis=axis,
         dtype=array.dtype,
+        method=method,
+        preop=nanlast,
     )
+
+    if n is not None and 0 < n < array.shape[axis] - 1:
+
+        def _reset_cumsum(a, axis, dtype=None):
+            cumsum = np.cumsum(a, axis=axis)
+            reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
+            return cumsum - reset_points
+
+        def _last_reset_cumsum(a, axis, keepdims=None):
+            # Take the last cumulative sum, taking the resets into account
+            # This is useful for the blelloch method
+            return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
+
+        def _combine_reset_cumsum(a, b):
+            # Sum the result of the previous block until the first
+            # non-NaN value
+            bitmask = np.cumprod(b != 0, axis=axis)
+            return np.where(bitmask, b + a, b)
+
+        valid_positions = da.reductions.cumreduction(
+            func=_reset_cumsum,
+            binop=_combine_reset_cumsum,
+            ident=0,
+            x=da.isnan(array, dtype=int),
+            axis=axis,
+            dtype=int,
+            method=method,
+            preop=_last_reset_cumsum,
+        )
+        pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
+
+    return pushed_array
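
For readers following the new limit handling above, a rough single-block NumPy sketch of the idea (not part of the diff): a resetting cumulative sum counts how deep each position sits inside a run of NaNs, and any forward-filled position deeper than n is masked out again. The helper mirrors _reset_cumsum above; the hard-coded unlimited_fill stands in for the result of push without a limit:

    import numpy as np

    def reset_cumsum(a, axis):
        # cumulative sum that restarts after every 0 (every non-NaN position)
        cumsum = np.cumsum(a, axis=axis)
        reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
        return cumsum - reset_points

    array = np.array([1.0, np.nan, np.nan, np.nan, 4.0, np.nan])
    n = 2

    # depth of each position within its run of NaNs -> [0 1 2 3 0 1]
    valid_positions = reset_cumsum(np.isnan(array).astype(int), axis=0)

    unlimited_fill = np.array([1.0, 1.0, 1.0, 1.0, 4.0, 4.0])  # push with no limit
    limited_fill = np.where(valid_positions <= n, unlimited_fill, np.nan)
    print(limited_fill)  # [ 1.  1.  1. nan  4.  4.]
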
6 changes: 4 additions & 2 deletions xarray/core/duck_array_ops.py
@@ -716,6 +716,7 @@ def first(values, axis, skipna=None):
             return chunked_nanfirst(values, axis)
         else:
             return nputils.nanfirst(values, axis)
+
     return take(values, 0, axis=axis)


@@ -729,6 +730,7 @@ def last(values, axis, skipna=None):
             return chunked_nanlast(values, axis)
         else:
             return nputils.nanlast(values, axis)
+
     return take(values, -1, axis=axis)


@@ -769,14 +771,14 @@ def _push(array, n: int | None = None, axis: int = -1):
     return bn.push(array, limit, axis)


-def push(array, n, axis):
+def push(array, n, axis, method="blelloch"):
     if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
         raise RuntimeError(
             "ffill & bfill requires bottleneck or numbagg to be enabled."
             " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
         )
     if is_duck_dask_array(array):
-        return dask_array_ops.push(array, n, axis)
+        return dask_array_ops.push(array, n, axis, method=method)
     else:
         return _push(array, n, axis)

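A short sketch (not part of the diff) of how this dispatcher behaves, assuming bottleneck or numbagg is installed; in-memory arrays keep going through _push, while dask arrays are routed to dask_array_ops.push with the chosen scan method:

    import numpy as np
    import xarray as xr
    from xarray.core.duck_array_ops import push

    arr = np.array([np.nan, 1.0, np.nan, np.nan, 4.0])

    # numpy input: handled by bottleneck/numbagg via _push
    print(push(arr, n=1, axis=-1))  # [nan  1.  1. nan  4.]

    # with both backends disabled, the RuntimeError above is raised
    with xr.set_options(use_bottleneck=False, use_numbagg=False):
        try:
            push(arr, n=1, axis=-1)
        except RuntimeError as err:
            print(err)
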
12 changes: 9 additions & 3 deletions xarray/tests/test_duck_array_ops.py
@@ -1008,7 +1008,8 @@ def test_least_squares(use_dask, skipna):

 @requires_dask
 @requires_bottleneck
-def test_push_dask():
+@pytest.mark.parametrize("method", ["sequential", "blelloch"])
+def test_push_dask(method):
     import bottleneck
     import dask.array

@@ -1018,13 +1019,18 @@ def test_push_dask():
         expected = bottleneck.push(array, axis=0, n=n)
         for c in range(1, 11):
             with raise_if_dask_computes():
-                actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n)
+                actual = push(
+                    dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
+                )
             np.testing.assert_equal(actual, expected)

         # some chunks of size-1 with NaN
         with raise_if_dask_computes():
             actual = push(
-                dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n
+                dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
+                axis=0,
+                n=n,
+                method=method,
             )
         np.testing.assert_equal(actual, expected)
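
If extra coverage were wanted, a hypothetical focused check (not part of this PR) could compare both scan methods against the bottleneck reference specifically on the limit path. It assumes the helpers already imported at the top of this test module (push, raise_if_dask_computes, requires_dask, requires_bottleneck, pytest, np):

    @requires_dask
    @requires_bottleneck
    @pytest.mark.parametrize("method", ["sequential", "blelloch"])
    @pytest.mark.parametrize("n", [1, 2, 5])
    def test_push_dask_limit(method, n):
        import bottleneck
        import dask.array

        array = np.array([np.nan, 1, 2, np.nan, np.nan, np.nan, 6, np.nan])
        expected = bottleneck.push(array, axis=0, n=n)
        with raise_if_dask_computes():
            actual = push(
                dask.array.from_array(array, chunks=2), axis=0, n=n, method=method
            )
        np.testing.assert_equal(actual, expected)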