Skip to content
Merged
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg by
default, which is up to 5x faster where parallelization is possible. (:pull:`8339`)
By `Maximilian Roos <https://github.com/max-sixty>`_.

.. _whats-new.2023.11.0:

v2023.11.0 (Nov 16, 2023)
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
# DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
# this avoids the need to get involved in zarr synchronization / locking
# From zarr docs:
# "If each worker in a parallel computation is writing to a separate
# region of the array, and if region boundaries are perfectly aligned
# "If each worker in a parallel computation is writing to a
# separate region of the array, and if region boundaries are perfectly aligned
# with chunk boundaries, then no synchronization is required."
# TODO: incorporate synchronizer to allow writes from multiple dask
# threads
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
import bottleneck
import dask.array as da
import numpy as np

from xarray.core.duck_array_ops import _push

def _fill_with_last_one(a, b):
# cumreduction apply the push func over all the blocks first so, the only missing part is filling
# the missing values using the last data of the previous chunk
Expand All @@ -85,7 +86,7 @@ def _fill_with_last_one(a, b):

# The method parameter makes that the tests for python 3.7 fails.
return da.reductions.cumreduction(
func=bottleneck.push,
func=_push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
Expand Down
41 changes: 37 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from numpy import concatenate as _concatenate
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core import dask_array_ops, dtypes, nputils, pycompat
from xarray.core.options import OPTIONS
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.pycompat import array_type, is_duck_dask_array
from xarray.core.utils import is_duck_array, module_available
Expand Down Expand Up @@ -688,13 +690,44 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)


def push(array, n, axis):
from bottleneck import push
def _push(array, n: int | None = None, axis: int = -1):
"""
Use either bottleneck or numbagg depending on options & what's available
"""

if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
raise RuntimeError(
"ffill & bfill requires bottleneck or numbagg to be enabled."
" Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
)
if OPTIONS["use_numbagg"] and module_available("numbagg"):
import numbagg

if pycompat.mod_version("numbagg") < Version("0.6.2"):
warnings.warn(
f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead."
)
else:
return numbagg.ffill(array, limit=n, axis=axis)

# work around for bottleneck 178
limit = n if n is not None else array.shape[axis]

import bottleneck as bn

return bn.push(array, limit, axis)


def push(array, n, axis):
if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
raise RuntimeError(
"ffill & bfill requires bottleneck or numbagg to be enabled."
" Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
)
if is_duck_dask_array(array):
return dask_array_ops.push(array, n, axis)
else:
return push(array, n, axis)
return _push(array, n, axis)


def _first_last_wrapper(array, *, axis, op, keepdims):
Expand Down
12 changes: 1 addition & 11 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from xarray.core.common import _contains_datetime_like_objects, ones_like
from xarray.core.computation import apply_ufunc
from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.options import _get_keep_attrs
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.types import Interp1dOptions, InterpOptions
from xarray.core.utils import OrderedSet, is_scalar
Expand Down Expand Up @@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1):

def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"ffill requires bottleneck to be enabled."
" Call `xr.set_options(use_bottleneck=True)` to enable it."
)

axis = arr.get_axis_num(dim)

Expand All @@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None):

def bfill(arr, dim=None, limit=None):
"""backfill missing values"""
if not OPTIONS["use_bottleneck"]:
raise RuntimeError(
"bfill requires bottleneck to be enabled."
" Call `xr.set_options(use_bottleneck=True)` to enable it."
)

axis = arr.get_axis_num(dim)

Expand Down
34 changes: 16 additions & 18 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import warnings
from typing import Callable

import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from packaging.version import Version

from xarray.core import pycompat
from xarray.core.utils import module_available

# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore]
Expand All @@ -25,15 +29,6 @@
bn = np
_BOTTLENECK_AVAILABLE = False

try:
import numbagg

_HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0")
except ImportError:
# use numpy methods instead
numbagg = np # type: ignore
_HAS_NUMBAGG = False


def _select_along_axis(values, idx, axis):
other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
Expand Down Expand Up @@ -171,29 +166,32 @@ def __setitem__(self, key, value):
self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)


def _create_method(name, npmodule=np):
def _create_method(name, npmodule=np) -> Callable:
def f(values, axis=None, **kwargs):
dtype = kwargs.get("dtype", None)
bn_func = getattr(bn, name, None)
nba_func = getattr(numbagg, name, None)

if (
_HAS_NUMBAGG
module_available("numbagg")
and pycompat.mod_version("numbagg") >= Version("0.5.0")
and OPTIONS["use_numbagg"]
and isinstance(values, np.ndarray)
and nba_func is not None
# numbagg uses ddof=1 only, but numpy uses ddof=0 by default
and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
# TODO: bool?
and values.dtype.kind in "uifc"
# and values.dtype.isnative
and (dtype is None or np.dtype(dtype) == values.dtype)
):
# numbagg does not take care dtype, ddof
kwargs.pop("dtype", None)
kwargs.pop("ddof", None)
result = nba_func(values, axis=axis, **kwargs)
elif (
import numbagg

nba_func = getattr(numbagg, name, None)
if nba_func is not None:
# numbagg does not take care dtype, ddof
kwargs.pop("dtype", None)
kwargs.pop("ddof", None)
return nba_func(values, axis=axis, **kwargs)
if (
_BOTTLENECK_AVAILABLE
and OPTIONS["use_bottleneck"]
and isinstance(values, np.ndarray)
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
integer_types = (int, np.integer)

if TYPE_CHECKING:
ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"]
ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic


Expand Down Expand Up @@ -47,6 +47,9 @@ def __init__(self, mod: ModType) -> None:
duck_array_type = (duck_array_module.SparseArray,)
elif mod == "cubed":
duck_array_type = (duck_array_module.Array,)
# Not a duck array module, but using this system regardless, to get lazy imports
elif mod == "numbagg":
duck_array_type = ()
else:
raise NotImplementedError

Expand Down
50 changes: 26 additions & 24 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@
import numpy as np
from packaging.version import Version

from xarray.core import pycompat
from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
from xarray.core.types import T_DataWithCoords

try:
import numbagg
from numbagg import move_exp_nanmean, move_exp_nansum

_NUMBAGG_VERSION: Version | None = Version(numbagg.__version__)
except ImportError:
_NUMBAGG_VERSION = None
from xarray.core.utils import module_available


def _get_alpha(
Expand Down Expand Up @@ -83,17 +77,17 @@ def __init__(
window_type: str = "span",
min_weight: float = 0.0,
):
if _NUMBAGG_VERSION is None:
if not module_available("numbagg"):
raise ImportError(
"numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
)
elif _NUMBAGG_VERSION < Version("0.2.1"):
elif pycompat.mod_version("numbagg") < Version("0.2.1"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we boost min supported numbagg version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we basically keep the 12 month cycle — it's some good infra, and I would like to be a good citizen for that process rather than deviate to save myself some boilerplate...

raise ImportError(
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed"
)
elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0:
elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0:
raise ImportError(
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed"
)

self.obj: T_DataWithCoords = obj
Expand Down Expand Up @@ -127,13 +121,15 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

import numbagg

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dim_order = self.obj.dims

return apply_ufunc(
move_exp_nanmean,
numbagg.move_exp_nanmean,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
Expand Down Expand Up @@ -163,13 +159,15 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

import numbagg

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

dim_order = self.obj.dims

return apply_ufunc(
move_exp_nansum,
numbagg.move_exp_nansum,
self.obj,
input_core_dims=[[self.dim]],
kwargs=self.kwargs,
Expand All @@ -194,10 +192,12 @@ def std(self) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
import numbagg

dim_order = self.obj.dims

return apply_ufunc(
Expand Down Expand Up @@ -225,12 +225,12 @@ def var(self) -> T_DataWithCoords:
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nanvar,
Expand Down Expand Up @@ -258,11 +258,12 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nancov,
Expand Down Expand Up @@ -291,11 +292,12 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
Dimensions without coordinates: x
"""

if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"):
if pycompat.mod_version("numbagg") < Version("0.4.0"):
raise ImportError(
f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed"
f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
)
dim_order = self.obj.dims
import numbagg

return apply_ufunc(
numbagg.move_exp_nancorr,
Expand Down
7 changes: 6 additions & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def _importorskip(
mod = importlib.import_module(modname)
has = True
if minversion is not None:
if Version(mod.__version__) < Version(minversion):
v = getattr(mod, "__version__", "999")
if Version(v) < Version(minversion):
raise ImportError("Minimum version not satisfied")
except ImportError:
has = False
Expand Down Expand Up @@ -96,6 +97,10 @@ def _importorskip(
requires_scipy_or_netCDF4 = pytest.mark.skipif(
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
requires_numbagg_or_bottleneck = pytest.mark.skipif(
not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
# _importorskip does not work for development versions
has_pandas_version_two = Version(pd.__version__).major >= 2
requires_pandas_version_two = pytest.mark.skipif(
Expand Down
Loading