Skip to content

Commit 219ef0c

Browse files
authored
Offer a fixture for unifying DataArray & Dataset tests (#8533)
* Add Cumulative aggregation Offer a fixture for unifying `DataArray` & `Dataset` tests (stacked on #8512, worth reviewing after that's merged) Some tests are literally copy & pasted between DataArray & Dataset tests. This change allows them to use a single test. Not everything will work — sometimes we want to check specifics — but sometimes they will...
1 parent 766da34 commit 219ef0c

File tree

2 files changed

+68
-42
lines changed

2 files changed

+68
-42
lines changed

xarray/tests/conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
@@ -77,3 +79,44 @@ def da(request, backend):
7779
return da
7880
else:
7981
raise ValueError
82+
83+
84+
@pytest.fixture(params=[Dataset, DataArray])
85+
def type(request):
86+
return request.param
87+
88+
89+
@pytest.fixture(params=[1])
90+
def d(request, backend, type) -> DataArray | Dataset:
91+
"""
92+
For tests which can test either a DataArray or a Dataset.
93+
"""
94+
result: DataArray | Dataset
95+
if request.param == 1:
96+
ds = Dataset(
97+
dict(
98+
a=(["x", "z"], np.arange(24).reshape(2, 12)),
99+
b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)),
100+
),
101+
dict(
102+
x=("x", np.linspace(0, 1.0, 2)),
103+
y=range(3),
104+
z=("z", pd.date_range("2000-01-01", periods=12)),
105+
w=("x", ["a", "b"]),
106+
),
107+
)
108+
if type == DataArray:
109+
result = ds["a"].assign_coords(w=ds.coords["w"])
110+
elif type == Dataset:
111+
result = ds
112+
else:
113+
raise ValueError
114+
else:
115+
raise ValueError
116+
117+
if backend == "dask":
118+
return result.chunk()
119+
elif backend == "numpy":
120+
return result
121+
else:
122+
raise ValueError

xarray/tests/test_rolling.py

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,31 @@ def compute_backend(request):
3636
yield request.param
3737

3838

39+
@pytest.mark.parametrize("func", ["mean", "sum"])
40+
@pytest.mark.parametrize("min_periods", [1, 10])
41+
def test_cumulative(d, func, min_periods) -> None:
42+
# One dim
43+
result = getattr(d.cumulative("z", min_periods=min_periods), func)()
44+
expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)()
45+
assert_identical(result, expected)
46+
47+
# Multiple dim
48+
result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)()
49+
expected = getattr(
50+
d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods),
51+
func,
52+
)()
53+
assert_identical(result, expected)
54+
55+
56+
def test_cumulative_vs_cum(d) -> None:
57+
result = d.cumulative("z").sum()
58+
expected = d.cumsum("z")
59+
# cumsum drops the coord of the dimension; cumulative doesn't
60+
expected = expected.assign_coords(z=result["z"])
61+
assert_identical(result, expected)
62+
63+
3964
class TestDataArrayRolling:
4065
@pytest.mark.parametrize("da", (1, 2), indirect=True)
4166
@pytest.mark.parametrize("center", [True, False])
@@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
485510
):
486511
da.rolling_exp(time=10, keep_attrs=True)
487512

488-
@pytest.mark.parametrize("func", ["mean", "sum"])
489-
@pytest.mark.parametrize("min_periods", [1, 20])
490-
def test_cumulative(self, da, func, min_periods) -> None:
491-
# One dim
492-
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
493-
expected = getattr(
494-
da.rolling(time=da.time.size, min_periods=min_periods), func
495-
)()
496-
assert_identical(result, expected)
497-
498-
# Multiple dim
499-
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
500-
expected = getattr(
501-
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
502-
func,
503-
)()
504-
assert_identical(result, expected)
505-
506-
def test_cumulative_vs_cum(self, da) -> None:
507-
result = da.cumulative("time").sum()
508-
expected = da.cumsum("time")
509-
assert_identical(result, expected)
510-
511513

512514
class TestDatasetRolling:
513515
@pytest.mark.parametrize(
@@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
832834
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
833835
assert_allclose(actual, expected)
834836

835-
@pytest.mark.parametrize("func", ["mean", "sum"])
836-
@pytest.mark.parametrize("ds", (2,), indirect=True)
837-
@pytest.mark.parametrize("min_periods", [1, 10])
838-
def test_cumulative(self, ds, func, min_periods) -> None:
839-
# One dim
840-
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
841-
expected = getattr(
842-
ds.rolling(time=ds.time.size, min_periods=min_periods), func
843-
)()
844-
assert_identical(result, expected)
845-
846-
# Multiple dim
847-
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
848-
expected = getattr(
849-
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
850-
func,
851-
)()
852-
assert_identical(result, expected)
853-
854837

855838
@requires_numbagg
856839
class TestDatasetRollingExp:

0 commit comments

Comments
 (0)