Skip to content

Commit 8a338ee

Browse files
dcherian and mathause authored
Add coarsen.construct (#5476)
Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
1 parent c5ae488 commit 8a338ee

File tree

6 files changed

+195
-7
lines changed

6 files changed

+195
-7
lines changed

doc/api-hidden.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
core.rolling.DatasetCoarsen.all
4343
core.rolling.DatasetCoarsen.any
44+
core.rolling.DatasetCoarsen.construct
4445
core.rolling.DatasetCoarsen.count
4546
core.rolling.DatasetCoarsen.max
4647
core.rolling.DatasetCoarsen.mean
@@ -185,6 +186,7 @@
185186

186187
core.rolling.DataArrayCoarsen.all
187188
core.rolling.DataArrayCoarsen.any
189+
core.rolling.DataArrayCoarsen.construct
188190
core.rolling.DataArrayCoarsen.count
189191
core.rolling.DataArrayCoarsen.max
190192
core.rolling.DataArrayCoarsen.mean

doc/howdoi.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ How do I ...
2424
* - change the order of dimensions
2525
- :py:meth:`DataArray.transpose`, :py:meth:`Dataset.transpose`
2626
* - reshape dimensions
27-
- :py:meth:`DataArray.stack`, :py:meth:`Dataset.stack`
27+
- :py:meth:`DataArray.stack`, :py:meth:`Dataset.stack`, :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct`
2828
* - remove a variable from my object
2929
- :py:meth:`Dataset.drop_vars`, :py:meth:`DataArray.drop_vars`
3030
* - remove dimensions of length 1 or 0

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ v0.18.3 (unreleased)
2222
New Features
2323
~~~~~~~~~~~~
2424

25+
- Added :py:meth:`Dataset.coarsen.construct`, :py:meth:`DataArray.coarsen.construct` (:issue:`5454`, :pull:`5476`).
26+
By `Deepak Cherian <https://github.com/dcherian>`_.
2527
- Xarray now uses consolidated metadata by default when writing and reading Zarr
2628
stores (:issue:`5251`).
2729
By `Stephan Hoyer <https://github.com/shoyer>`_.

xarray/core/rolling.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import itertools
23
import warnings
34
from typing import Any, Callable, Dict
45

@@ -8,6 +9,7 @@
89
from .arithmetic import CoarsenArithmetic
910
from .options import _get_keep_attrs
1011
from .pycompat import is_duck_dask_array
12+
from .utils import either_dict_or_kwargs
1113

1214
try:
1315
import bottleneck
@@ -845,6 +847,109 @@ def __repr__(self):
845847
klass=self.__class__.__name__, attrs=",".join(attrs)
846848
)
847849

850+
def construct(
    self,
    window_dim=None,
    keep_attrs=None,
    **window_dim_kwargs,
):
    """
    Convert this Coarsen object to a DataArray or Dataset,
    where each coarsening dimension is split (reshaped) into two
    new dimensions.

    Parameters
    ----------
    window_dim : mapping, optional
        A mapping from existing dimension name to a pair of new
        dimension names. The size of the second dimension will be the
        length of the coarsening window.
    keep_attrs : bool, optional
        Preserve attributes if True. If None, the value is taken from
        ``_get_keep_attrs(default=True)``.
    **window_dim_kwargs : {dim: new_name, ...}
        The keyword arguments form of ``window_dim``.

    Returns
    -------
    Dataset or DataArray with reshaped dimensions

    Raises
    ------
    ValueError
        If neither ``window_dim`` nor ``window_dim_kwargs`` is given, if
        any entry is not a pair of dimension names, if a coarsened
        dimension has no entry in ``window_dim``, or if ``window_dim``
        names a dimension that is not being coarsened.

    Examples
    --------
    >>> da = xr.DataArray(np.arange(24), dims="time")
    >>> da.coarsen(time=12).construct(time=("year", "month"))
    <xarray.DataArray (year: 2, month: 12)>
    array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
           [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
    Dimensions without coordinates: year, month

    See Also
    --------
    DataArrayRolling.construct
    DatasetRolling.construct
    """

    from .dataarray import DataArray
    from .dataset import Dataset

    window_dim = either_dict_or_kwargs(
        window_dim, window_dim_kwargs, "Coarsen.construct"
    )
    if not window_dim:
        raise ValueError(
            "Either window_dim or window_dim_kwargs need to be specified."
        )

    # A plain string is a single name, not a pair, even when it happens to
    # be two characters long; test for it before looking at the length.
    bad_new_dims = tuple(
        win
        for win, dims in window_dim.items()
        if isinstance(dims, str) or len(dims) != 2
    )
    if bad_new_dims:
        raise ValueError(
            f"Please provide exactly two dimension names for the following coarsening dimensions: {bad_new_dims}"
        )

    if keep_attrs is None:
        keep_attrs = _get_keep_attrs(default=True)

    # Names given in ``window_dim`` that this object is not coarsening.
    extra_dims = set(window_dim) - set(self.windows)
    if extra_dims:
        raise ValueError(
            f"'window_dim' includes dimensions that will not be coarsened: {extra_dims}"
        )
    # Coarsened dimensions for which no pair of new names was supplied.
    missing_dims = set(self.windows) - set(window_dim)
    if missing_dims:
        raise ValueError(
            f"'window_dim' must contain entries for all dimensions to coarsen. Missing {missing_dims}"
        )

    reshaped = Dataset()
    # Work on a Dataset in both cases so a single loop handles data
    # variables and coordinates alike.
    if isinstance(self.obj, DataArray):
        obj = self.obj._to_temp_dataset()
    else:
        obj = self.obj

    reshaped.attrs = obj.attrs if keep_attrs else {}

    for key, var in obj.variables.items():
        # Replace every coarsened dimension by its pair of new names and
        # keep all other dimensions as they are.
        reshaped_dims = tuple(
            itertools.chain.from_iterable(
                window_dim.get(dim, [dim]) for dim in var.dims
            )
        )
        if reshaped_dims != var.dims:
            windows = {w: self.windows[w] for w in window_dim if w in var.dims}
            # Only the reshaped data is needed; the second element of the
            # returned pair is ignored here.
            reshaped_var, _ = var.coarsen_reshape(windows, self.boundary, self.side)
            attrs = var.attrs if keep_attrs else {}
            reshaped[key] = (reshaped_dims, reshaped_var, attrs)
        else:
            # Variable has no coarsened dimension: carry it over unchanged.
            reshaped[key] = var

    # Assigning into ``reshaped`` by key does not mark coordinates, so
    # restore coordinate status for coarsened dimensions that were coords.
    should_be_coords = set(window_dim) & set(self.obj.coords)
    result = reshaped.set_coords(should_be_coords)
    if isinstance(self.obj, DataArray):
        return self.obj._from_temp_dataset(result)
    else:
        return result
952+
848953

849954
class DataArrayCoarsen(Coarsen):
850955
__slots__ = ()

xarray/core/variable.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,7 +2158,7 @@ def coarsen(
21582158
if not windows:
21592159
return self._replace(attrs=_attrs)
21602160

2161-
reshaped, axes = self._coarsen_reshape(windows, boundary, side)
2161+
reshaped, axes = self.coarsen_reshape(windows, boundary, side)
21622162
if isinstance(func, str):
21632163
name = func
21642164
func = getattr(duck_array_ops, name, None)
@@ -2167,7 +2167,7 @@ def coarsen(
21672167

21682168
return self._replace(data=func(reshaped, axis=axes, **kwargs), attrs=_attrs)
21692169

2170-
def _coarsen_reshape(self, windows, boundary, side):
2170+
def coarsen_reshape(self, windows, boundary, side):
21712171
"""
21722172
Construct a reshaped-array for coarsen
21732173
"""
@@ -2183,7 +2183,9 @@ def _coarsen_reshape(self, windows, boundary, side):
21832183

21842184
for d, window in windows.items():
21852185
if window <= 0:
2186-
raise ValueError(f"window must be > 0. Given {window}")
2186+
raise ValueError(
2187+
f"window must be > 0. Given {window} for dimension {d}"
2188+
)
21872189

21882190
variable = self
21892191
for d, window in windows.items():
@@ -2193,8 +2195,8 @@ def _coarsen_reshape(self, windows, boundary, side):
21932195
if boundary[d] == "exact":
21942196
if n * window != size:
21952197
raise ValueError(
2196-
"Could not coarsen a dimension of size {} with "
2197-
"window {}".format(size, window)
2198+
f"Could not coarsen a dimension of size {size} with "
2199+
f"window {window} and boundary='exact'. Try a different 'boundary' option."
21982200
)
21992201
elif boundary[d] == "trim":
22002202
if side[d] == "left":

xarray/tests/test_coarsen.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
import xarray as xr
66
from xarray import DataArray, Dataset, set_options
77

8-
from . import assert_allclose, assert_equal, has_dask, requires_cftime
8+
from . import (
9+
assert_allclose,
10+
assert_equal,
11+
assert_identical,
12+
has_dask,
13+
raise_if_dask_computes,
14+
requires_cftime,
15+
)
916
from .test_dataarray import da
1017
from .test_dataset import ds
1118

@@ -299,3 +306,73 @@ def test_coarsen_da_reduce(da, window, name):
299306
actual = coarsen_obj.reduce(getattr(np, f"nan{name}"))
300307
expected = getattr(coarsen_obj, name)()
301308
assert_allclose(actual, expected)
309+
310+
311+
@pytest.mark.parametrize("dask", [True, False])
def test_coarsen_construct(dask):
    """Check ``coarsen(...).construct`` on a Dataset and a DataArray.

    Covers the mapping and keyword call forms, ``keep_attrs=False``, and
    the argument combinations that must raise ``ValueError``.
    """
    # The same old-name -> (new, new) mapping is used by every valid call.
    window_dim = {"time": ("year", "month"), "x": ("x", "x_reshaped")}

    source = Dataset(
        {
            "vart": ("time", np.arange(48), {"a": "b"}),
            "varx": ("x", np.arange(10), {"a": "b"}),
            "vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}),
            "vary": ("y", np.arange(12)),
        },
        coords={"time": np.arange(48), "y": np.arange(12)},
        attrs={"foo": "bar"},
    )

    if dask and has_dask:
        source = source.chunk({"x": 4, "time": 10})

    # Build the expected result by reshaping the raw arrays by hand.
    expected = xr.Dataset(attrs={"foo": "bar"})
    expected["vart"] = (
        ("year", "month"),
        source["vart"].data.reshape((-1, 12)),
        {"a": "b"},
    )
    expected["varx"] = (
        ("x", "x_reshaped"),
        source["varx"].data.reshape((-1, 5)),
        {"a": "b"},
    )
    expected["vartx"] = (
        ("x", "x_reshaped", "year", "month"),
        source["vartx"].data.reshape(2, 5, 4, 12),
        {"a": "b"},
    )
    expected["vary"] = source["vary"]
    expected.coords["time"] = (
        ("year", "month"),
        source["time"].data.reshape((-1, 12)),
    )

    # Mapping form.
    with raise_if_dask_computes():
        actual = source.coarsen(time=12, x=5).construct(window_dim)
    assert_identical(actual, expected)

    # Keyword form.
    with raise_if_dask_computes():
        actual = source.coarsen(time=12, x=5).construct(**window_dim)
    assert_identical(actual, expected)

    # keep_attrs=False: neither the variables nor the dataset keep attrs.
    with raise_if_dask_computes():
        actual = source.coarsen(time=12, x=5).construct(window_dim, keep_attrs=False)
        for name in actual:
            assert actual[name].attrs == {}
        assert actual.attrs == {}

    # DataArray input.
    with raise_if_dask_computes():
        actual = source["vartx"].coarsen(time=12, x=5).construct(window_dim)
    assert_identical(actual, expected["vartx"])

    # Each (coarsen kwargs, construct kwargs) pair below must be rejected.
    invalid_calls = [
        ({"time": 12}, {"foo": "bar"}),
        ({"time": 12, "x": 2}, {"time": ("year", "month")}),
        ({"time": 12}, {}),
        ({"time": 12}, {"time": "bar"}),
        ({"time": 12}, {"time": ("bar",)}),
    ]
    for coarsen_kwargs, construct_kwargs in invalid_calls:
        with pytest.raises(ValueError):
            source.coarsen(**coarsen_kwargs).construct(**construct_kwargs)

0 commit comments

Comments
 (0)