Skip to content

Commit d9d4098

Browse files
committed
slight improvements to tests
1 parent 1ed4789 commit d9d4098

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

xarray/tests/test_computation.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,22 +1390,26 @@ def test_corr_only_dataarray() -> None:
13901390
xr.corr(xr.Dataset(), xr.Dataset()) # type: ignore[type-var]
13911391

13921392

1393-
def arrays_w_tuples():
1393+
@pytest.fixture(scope="module")
1394+
def arrays():
13941395
da = xr.DataArray(
13951396
np.random.random((3, 21, 4)),
13961397
coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)},
13971398
dims=("a", "time", "x"),
13981399
)
13991400

1400-
arrays = [
1401+
return [
14011402
da.isel(time=range(0, 18)),
14021403
da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(),
14031404
xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]),
14041405
xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]),
14051406
xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]),
14061407
]
14071408

1408-
array_tuples = [
1409+
1410+
@pytest.fixture(scope="module")
1411+
def array_tuples(arrays):
1412+
return [
14091413
(arrays[0], arrays[0]),
14101414
(arrays[0], arrays[1]),
14111415
(arrays[1], arrays[1]),
@@ -1417,27 +1421,19 @@ def arrays_w_tuples():
14171421
(arrays[4], arrays[4]),
14181422
]
14191423

1420-
return arrays, array_tuples
1421-
14221424

14231425
@pytest.mark.parametrize("ddof", [0, 1])
1424-
@pytest.mark.parametrize(
1425-
"da_a, da_b",
1426-
[
1427-
arrays_w_tuples()[1][3],
1428-
arrays_w_tuples()[1][4],
1429-
arrays_w_tuples()[1][5],
1430-
arrays_w_tuples()[1][6],
1431-
arrays_w_tuples()[1][7],
1432-
arrays_w_tuples()[1][8],
1433-
],
1434-
)
1426+
@pytest.mark.parametrize("n", [3, 4, 5, 6, 7, 8])
14351427
@pytest.mark.parametrize("dim", [None, "x", "time"])
14361428
@requires_dask
1437-
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
1429+
def test_lazy_corrcov(
1430+
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
1431+
) -> None:
14381432
# GH 5284
14391433
from dask import is_dask_collection
14401434

1435+
da_a, da_b = array_tuples[n]
1436+
14411437
with raise_if_dask_computes():
14421438
cov = xr.cov(da_a.chunk(), da_b.chunk(), dim=dim, ddof=ddof)
14431439
assert is_dask_collection(cov)
@@ -1447,12 +1443,13 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
14471443

14481444

14491445
@pytest.mark.parametrize("ddof", [0, 1])
1450-
@pytest.mark.parametrize(
1451-
"da_a, da_b",
1452-
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
1453-
)
1446+
@pytest.mark.parametrize("n", [0, 1, 2])
14541447
@pytest.mark.parametrize("dim", [None, "time"])
1455-
def test_cov(da_a, da_b, dim, ddof) -> None:
1448+
def test_cov(
1449+
n: int, dim: str | None, ddof: int, array_tuples: tuple[xr.DataArray, xr.DataArray]
1450+
) -> None:
1451+
da_a, da_b = array_tuples[n]
1452+
14561453
if dim is not None:
14571454

14581455
def np_cov_ind(ts1, ts2, a, x):
@@ -1499,12 +1496,13 @@ def np_cov(ts1, ts2):
14991496
assert_allclose(actual, expected)
15001497

15011498

1502-
@pytest.mark.parametrize(
1503-
"da_a, da_b",
1504-
[arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]],
1505-
)
1499+
@pytest.mark.parametrize("n", [0, 1, 2])
15061500
@pytest.mark.parametrize("dim", [None, "time"])
1507-
def test_corr(da_a, da_b, dim) -> None:
1501+
def test_corr(
1502+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1503+
) -> None:
1504+
da_a, da_b = array_tuples[n]
1505+
15081506
if dim is not None:
15091507

15101508
def np_corr_ind(ts1, ts2, a, x):
@@ -1547,12 +1545,12 @@ def np_corr(ts1, ts2):
15471545
assert_allclose(actual, expected)
15481546

15491547

1550-
@pytest.mark.parametrize(
1551-
"da_a, da_b",
1552-
arrays_w_tuples()[1],
1553-
)
1548+
@pytest.mark.parametrize("n", range(9))
15541549
@pytest.mark.parametrize("dim", [None, "time", "x"])
1555-
def test_covcorr_consistency(da_a, da_b, dim) -> None:
1550+
def test_covcorr_consistency(
1551+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1552+
) -> None:
1553+
da_a, da_b = array_tuples[n]
15561554
# Testing that xr.corr and xr.cov are consistent with each other
15571555
# 1. Broadcast the two arrays
15581556
da_a, da_b = broadcast(da_a, da_b)
@@ -1569,10 +1567,13 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
15691567

15701568

15711569
@requires_dask
1572-
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
1570+
@pytest.mark.parametrize("n", range(9))
15731571
@pytest.mark.parametrize("dim", [None, "time", "x"])
15741572
@pytest.mark.filterwarnings("ignore:invalid value encountered in .*divide")
1575-
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
1573+
def test_corr_lazycorr_consistency(
1574+
n: int, dim: str | None, array_tuples: tuple[xr.DataArray, xr.DataArray]
1575+
) -> None:
1576+
da_a, da_b = array_tuples[n]
15761577
da_al = da_a.chunk()
15771578
da_bl = da_b.chunk()
15781579
c_abl = xr.corr(da_al, da_bl, dim=dim)
@@ -1591,19 +1592,18 @@ def test_corr_dtype_error():
15911592
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))
15921593

15931594

1594-
@pytest.mark.parametrize(
1595-
"da_a",
1596-
arrays_w_tuples()[0],
1597-
)
1595+
@pytest.mark.parametrize("n", range(5))
15981596
@pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]])
1599-
def test_autocov(da_a, dim) -> None:
1597+
def test_autocov(n: int, dim: str | None, arrays) -> None:
1598+
da = arrays[n]
1599+
16001600
# Testing that the autocovariance*(N-1) is ~=~ to the variance matrix
16011601
# 1. Ignore the nans
1602-
valid_values = da_a.notnull()
1602+
valid_values = da.notnull()
16031603
# Because we're using ddof=1, this requires > 1 value in each sample
1604-
da_a = da_a.where(valid_values.sum(dim=dim) > 1)
1605-
expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
1606-
actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1)
1604+
da = da.where(valid_values.sum(dim=dim) > 1)
1605+
expected = ((da - da.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1)
1606+
actual = xr.cov(da, da, dim=dim) * (valid_values.sum(dim) - 1)
16071607
assert_allclose(actual, expected)
16081608

16091609

0 commit comments

Comments
 (0)