Skip to content

move da and ds fixtures to conftest.py #6730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ ignore =
E501 # line too long - let black worry about that
E731 # do not assign a lambda expression, use a def
W503 # line break before binary operator
per-file-ignores =
xarray/tests/*.py:F401,F811
exclude=
.eggs
doc
Expand Down
74 changes: 73 additions & 1 deletion xarray/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,80 @@
import numpy as np
import pandas as pd
import pytest

from . import requires_dask
from xarray import DataArray, Dataset

from . import create_test_data, requires_dask


@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)])
def backend(request):
return request.param


@pytest.fixture(params=[1])
def ds(request, backend):
if request.param == 1:
ds = Dataset(
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)
elif request.param == 2:
ds = Dataset(
dict(
z1=(["time", "y"], np.random.randn(10, 2)),
z2=(["time"], np.random.randn(10)),
z3=(["x", "time"], np.random.randn(8, 10)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)
elif request.param == 3:
ds = create_test_data()
else:
raise ValueError

if backend == "dask":
return ds.chunk()

return ds


@pytest.fixture(params=[1])
def da(request, backend):
if request.param == 1:
times = pd.date_range("2000-01-01", freq="1D", periods=21)
da = DataArray(
np.random.random((3, 21, 4)),
dims=("a", "time", "x"),
coords=dict(time=times),
)

if request.param == 2:
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

if request.param == "repeating_ints":
da = DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={"x": list("abc"), "y": list("defg")},
dims=list("zyx"),
)

if backend == "dask":
return da.chunk()
elif backend == "numpy":
return da
else:
raise ValueError
2 changes: 0 additions & 2 deletions xarray/tests/test_coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
raise_if_dask_computes,
requires_cftime,
)
from .test_dataarray import da
from .test_dataset import ds


def test_coarsen_absent_dims_error(ds: Dataset) -> None:
Expand Down
11 changes: 2 additions & 9 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,8 @@
unified_dim_sizes,
)
from xarray.core.pycompat import dask_version
from xarray.core.types import T_Xarray

from . import (
has_cftime,
has_dask,
raise_if_dask_computes,
requires_cftime,
requires_dask,
)

from . import has_dask, raise_if_dask_computes, requires_cftime, requires_dask


def assert_identical(a, b):
Expand Down
28 changes: 0 additions & 28 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5859,34 +5859,6 @@ def test_idxminmax_dask(self, op, ndim) -> None:
assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x"))


@pytest.fixture(params=[1])
def da(request, backend):
if request.param == 1:
times = pd.date_range("2000-01-01", freq="1D", periods=21)
da = DataArray(
np.random.random((3, 21, 4)),
dims=("a", "time", "x"),
coords=dict(time=times),
)

if request.param == 2:
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")

if request.param == "repeating_ints":
da = DataArray(
np.tile(np.arange(12), 5).reshape(5, 4, 3),
coords={"x": list("abc"), "y": list("defg")},
dims=list("zyx"),
)

if backend == "dask":
return da.chunk()
elif backend == "numpy":
return da
else:
raise ValueError


@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True)
def test_isin(da) -> None:
expected = DataArray(
Expand Down
40 changes: 0 additions & 40 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6151,46 +6151,6 @@ def test_dir_unicode(ds) -> None:
assert "unicode" in result


@pytest.fixture(params=[1])
def ds(request, backend):
if request.param == 1:
ds = Dataset(
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)
elif request.param == 2:
ds = Dataset(
dict(
z1=(["time", "y"], np.random.randn(10, 2)),
z2=(["time"], np.random.randn(10)),
z3=(["x", "time"], np.random.randn(8, 10)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)
elif request.param == 3:
ds = create_test_data()
else:
raise ValueError

if backend == "dask":
return ds.chunk()

return ds


@pytest.mark.parametrize(
"funcname, argument",
[
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from xarray.core.variable import IndexVariable, Variable

from . import assert_equal, assert_identical
from . import assert_identical


def test_asarray_tuplesafe() -> None:
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import assert_array_equal
from . import assert_identical as assert_identical_
from . import assert_no_warnings, mock
from . import mock


def assert_identical(a, b):
Expand Down