nd-rolling #4219

Merged
merged 27 commits into from
Aug 8, 2020
Changes from 17 commits
9 changes: 8 additions & 1 deletion doc/computation.rst
@@ -188,9 +188,16 @@ a value when aggregating:
     r = arr.rolling(y=3, center=True, min_periods=2)
     r.mean()

+From version 0.17, xarray supports multidimensional rolling,
+
+.. ipython:: python
+
+    r = arr.rolling(x=2, y=3, min_periods=2)
+    r.mean()
+
 .. tip::

-   Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects.
+   Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling.

 .. _bottleneck: https://github.com/pydata/bottleneck/
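
As a rough illustration of what the documented call `arr.rolling(x=2, y=3, min_periods=2).mean()` is expected to compute, here is a brute-force NumPy reference. It is a sketch under stated assumptions, not xarray's implementation: the helper name `rolling_mean_2d` is hypothetical, the windows are trailing (as with `center=False`), and `min_periods` is assumed to count every non-NaN value inside the 2-D window.

```python
import numpy as np


def rolling_mean_2d(a, win_x, win_y, min_periods):
    """Trailing 2-D rolling mean of a float array (hypothetical helper).

    A cell is NaN unless its window holds at least `min_periods`
    non-NaN values; windows are clipped at the array's leading edges.
    """
    out = np.full(a.shape, np.nan)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            # Window ends at (i, j) and extends backwards along both axes.
            block = a[max(i - win_x + 1, 0): i + 1, max(j - win_y + 1, 0): j + 1]
            valid = block[~np.isnan(block)]
            if valid.size >= min_periods:
                out[i, j] = valid.mean()
    return out


a = np.arange(12, dtype=float).reshape(3, 4)
result = rolling_mean_2d(a, win_x=2, win_y=3, min_periods=2)
```

The explicit loops are slow but make the window bounds easy to check by hand, which can help when comparing against the output of the vectorized implementation.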

3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -25,6 +25,9 @@ Breaking changes

 New Features
 ~~~~~~~~~~~~
+- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
+  now accept more than 1 dimension.(:pull:`4219`)
+  By `Keisuke Fujii <https://github.com/fujiisoup>`_.
 - Build :py:meth:`CFTimeIndex.__repr__` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new
   property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in
   :py:meth:`CFTimeIndex.__repr__` (:issue:`2416`, :pull:`4092`)
2 changes: 1 addition & 1 deletion xarray/core/common.py
@@ -802,7 +802,7 @@ def rolling(
             Minimum number of observations in window required to have a value
             (otherwise result is NA). The default, None, is equivalent to
             setting min_periods equal to the size of the window.
-        center : boolean, default False
+        center : boolean, or a mapping, default False
Collaborator

Types on the signature should change too (indeed, I'm surprised mypy doesn't fail)

             Set the labels at the center of the window.
         keep_attrs : bool, optional
             If True, the object's attributes (`attrs`) will be copied from
8 changes: 8 additions & 0 deletions xarray/core/dask_array_ops.py
@@ -32,6 +32,14 @@ def rolling_window(a, axis, window, center, fill_value):
     """
     import dask.array as da

+    # for nd-rolling.
+    # TODO It can be more efficient. Currently, the chunks at the boundaries
+    # will be copied, but it might be OK for many-chunked-arrays.
+    if hasattr(axis, "__len__"):
+        for ax, win, cen in zip(axis, window, center):
+            a = rolling_window(a, ax, win, cen, fill_value)
+        return a
Member Author

I hope it is true, though I didn't check it.
A suspicious part is da.concatenate, which I expect does not copy the original (strided-)array.

Member Author

I noticed that ghosting breaks my expectation. I need to update the algo here.


     orig_shape = a.shape
     if axis < 0:
         axis = a.ndim + axis
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
@@ -5969,7 +5969,7 @@ def polyfit(
 skipna_da = np.any(da.isnull())

 dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
-stacked_coords = {}
+stacked_coords: Dict[Hashable, DataArray] = {}
 if dims_to_stack:
 stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
 rhs = da.transpose(dim, *dims_to_stack).stack(
22 changes: 15 additions & 7 deletions xarray/core/nputils.py
@@ -135,14 +135,22 @@ def __setitem__(self, key, value):
 def rolling_window(a, axis, window, center, fill_value):
     """ rolling window with padding. """
     pads = [(0, 0) for s in a.shape]
-    if center:
-        start = int(window / 2)  # 10 -> 5, 9 -> 4
-        end = window - 1 - start
-        pads[axis] = (start, end)
-    else:
-        pads[axis] = (window - 1, 0)
+    if not hasattr(axis, "__len__"):
+        axis = [axis]
+        window = [window]
+        center = [center]
+
+    for ax, win, cent in zip(axis, window, center):
+        if cent:
+            start = int(win / 2)  # 10 -> 5, 9 -> 4
+            end = win - 1 - start
+            pads[ax] = (start, end)
+        else:
+            pads[ax] = (win - 1, 0)
     a = np.pad(a, pads, mode="constant", constant_values=fill_value)
-    return _rolling_window(a, window, axis)
+    for ax, win in zip(axis, window):
+        a = _rolling_window(a, win, ax)
Collaborator

This is very clever! I spent a while trying to figure out how it works...

+    return a


 def _rolling_window(a, window, axis=-1):
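
The pad-once, window-per-axis idea in the `nputils.py` hunk above can be sketched in a self-contained form. This is an illustration under assumptions rather than the PR's code: `rolling_window_nd` is a hypothetical name, `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20 or later) stands in for the private `_rolling_window` helper, the input is assumed to be a float array, and the axes are assumed to be non-negative.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_window_nd(a, axes, windows, centers, fill_value=np.nan):
    """Pad `a` once, then append one window dimension per rolled axis.

    Hypothetical helper; `axes` must be non-negative because every
    windowing step appends a new trailing axis.
    """
    pads = [(0, 0)] * a.ndim
    for ax, win, cen in zip(axes, windows, centers):
        if cen:
            start = win // 2  # 10 -> 5, 9 -> 4
            end = win - 1 - start
            pads[ax] = (start, end)
        else:
            pads[ax] = (win - 1, 0)
    a = np.pad(a, pads, mode="constant", constant_values=fill_value)
    for ax, win in zip(axes, windows):
        # The rolled axis shrinks back to its original length and a new
        # last axis of length `win` is appended; other axes stay in place.
        a = sliding_window_view(a, win, axis=ax)
    return a


a = np.arange(12, dtype=float).reshape(3, 4)
w = rolling_window_nd(a, axes=[0, 1], windows=[2, 3], centers=[False, False])
rolled_mean = np.nanmean(w[2, 3])  # mean over the full 2 x 3 window ending at (2, 3)
```

Because each step appends its window dimension at the end, the original axes keep their positions, so the loop can reuse the same axis numbers on every pass. The result has shape `a.shape + tuple(windows)`.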