Skip to content

Commit

Permalink
Merge branch 'main' into test-unification
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty authored Dec 8, 2023
2 parents 4c8abcb + 9acc411 commit 39038f3
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,29 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
):
da.rolling_exp(time=10, keep_attrs=True)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("min_periods", [1, 20])
def test_cumulative(self, da, func, min_periods) -> None:
# One dim
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
expected = getattr(
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)

def test_cumulative_vs_cum(self, da) -> None:
result = da.cumulative("time").sum()
expected = da.cumsum("time")
assert_identical(result, expected)


class TestDatasetRolling:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -834,6 +857,25 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
assert_allclose(actual, expected)

@pytest.mark.parametrize("func", ["mean", "sum"])
@pytest.mark.parametrize("ds", (2,), indirect=True)
@pytest.mark.parametrize("min_periods", [1, 10])
def test_cumulative(self, ds, func, min_periods) -> None:
# One dim
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, min_periods=min_periods), func
)()
assert_identical(result, expected)

# Multiple dim
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
expected = getattr(
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
func,
)()
assert_identical(result, expected)


@requires_numbagg
class TestDatasetRollingExp:
Expand Down

0 comments on commit 39038f3

Please sign in to comment.