Skip to content

Commit

Permalink
Ensure Coarsen.construct keeps all coords (pydata#7233)
Browse files Browse the repository at this point in the history
* test

* fix

* whatsnew

* group related tests into a class

* Update xarray/core/rolling.py

* Update xarray/core/rolling.py

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
TomNicholas and dcherian authored Oct 28, 2022
1 parent 51d37d1 commit e1936a9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 64 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ Bug fixes
now reopens the file from scratch for h5netcdf and scipy netCDF backends,
rather than reusing a cached version (:issue:`4240`, :issue:`4862`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Fixed bug where :py:meth:`Dataset.coarsen.construct` would demote non-dimension coordinates to variables (:pull:`7233`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Raise a TypeError when trying to plot empty data (:issue:`7156`, :pull:`7228`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Expand Down
5 changes: 4 additions & 1 deletion xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,10 @@ def construct(
else:
reshaped[key] = var

should_be_coords = set(window_dim) & set(self.obj.coords)
# should handle window_dim being unindexed
should_be_coords = (set(window_dim) & set(self.obj.coords)) | set(
self.obj.coords
)
result = reshaped.set_coords(should_be_coords)
if isinstance(self.obj, DataArray):
return self.obj._from_temp_dataset(result)
Expand Down
146 changes: 83 additions & 63 deletions xarray/tests/test_coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,71 +250,91 @@ def test_coarsen_da_reduce(da, window, name) -> None:
assert_allclose(actual, expected)


@pytest.mark.parametrize("dask", [True, False])
def test_coarsen_construct(dask: bool) -> None:

ds = Dataset(
{
"vart": ("time", np.arange(48), {"a": "b"}),
"varx": ("x", np.arange(10), {"a": "b"}),
"vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}),
"vary": ("y", np.arange(12)),
},
coords={"time": np.arange(48), "y": np.arange(12)},
attrs={"foo": "bar"},
)

if dask and has_dask:
ds = ds.chunk({"x": 4, "time": 10})

expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (("year", "month"), ds.vart.data.reshape((-1, 12)), {"a": "b"})
expected["varx"] = (("x", "x_reshaped"), ds.varx.data.reshape((-1, 5)), {"a": "b"})
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
ds.vartx.data.reshape(2, 5, 4, 12),
{"a": "b"},
)
expected["vary"] = ds.vary
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
class TestCoarsenConstruct:
@pytest.mark.parametrize("dask", [True, False])
def test_coarsen_construct(self, dask: bool) -> None:

ds = Dataset(
{
"vart": ("time", np.arange(48), {"a": "b"}),
"varx": ("x", np.arange(10), {"a": "b"}),
"vartx": (("x", "time"), np.arange(480).reshape(10, 48), {"a": "b"}),
"vary": ("y", np.arange(12)),
},
coords={"time": np.arange(48), "y": np.arange(12)},
attrs={"foo": "bar"},
)
assert_identical(actual, expected)

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
time=("year", "month"), x=("x", "x_reshaped")
)
assert_identical(actual, expected)
if dask and has_dask:
ds = ds.chunk({"x": 4, "time": 10})

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False
expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (
("year", "month"),
ds.vart.data.reshape((-1, 12)),
{"a": "b"},
)
for var in actual:
assert actual[var].attrs == {}
assert actual.attrs == {}

with raise_if_dask_computes():
actual = ds.vartx.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
expected["varx"] = (
("x", "x_reshaped"),
ds.varx.data.reshape((-1, 5)),
{"a": "b"},
)
assert_identical(actual, expected["vartx"])

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(foo="bar")

with pytest.raises(ValueError):
ds.coarsen(time=12, x=2).construct(time=("year", "month"))

with pytest.raises(ValueError):
ds.coarsen(time=12).construct()

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(time="bar")

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(time=("bar",))
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
ds.vartx.data.reshape(2, 5, 4, 12),
{"a": "b"},
)
expected["vary"] = ds.vary
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
)
assert_identical(actual, expected)

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
time=("year", "month"), x=("x", "x_reshaped")
)
assert_identical(actual, expected)

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}, keep_attrs=False
)
for var in actual:
assert actual[var].attrs == {}
assert actual.attrs == {}

with raise_if_dask_computes():
actual = ds.vartx.coarsen(time=12, x=5).construct(
{"time": ("year", "month"), "x": ("x", "x_reshaped")}
)
assert_identical(actual, expected["vartx"])

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(foo="bar")

with pytest.raises(ValueError):
ds.coarsen(time=12, x=2).construct(time=("year", "month"))

with pytest.raises(ValueError):
ds.coarsen(time=12).construct()

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(time="bar")

with pytest.raises(ValueError):
ds.coarsen(time=12).construct(time=("bar",))

def test_coarsen_construct_keeps_all_coords(self) -> None:
    """Check that ``coarsen(...).construct(...)`` preserves every coordinate.

    Regression test for the case where a coordinate that is not itself a
    dimension (``day``) stopped being listed in ``.coords`` after
    ``construct``. The check is done for a DataArray and for the Dataset
    built from it.
    """
    da = xr.DataArray(np.arange(24), dims=["time"])
    # "day" is a coordinate along "time" that is not a dimension name.
    da = da.assign_coords(day=365 * da)
    expected_coords = list(da.coords)

    # DataArray path.
    result = da.coarsen(time=12).construct(time=("year", "month"))
    assert list(result.coords) == expected_coords

    # Dataset path: same data wrapped as variable "T".
    ds = da.to_dataset(name="T")
    result = ds.coarsen(time=12).construct(time=("year", "month"))
    assert list(result.coords) == expected_coords

0 comments on commit e1936a9

Please sign in to comment.