Skip to content

Use backend in ds fixture #5411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 36 additions & 29 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6078,11 +6078,6 @@ def test_query(self, backend, engine, parser):
# pytest tests — new tests should go here, rather than in the class.


@pytest.fixture(params=[None])
def data_set(request):
return create_test_data(request.param)


@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2])))
def test_isin(test_elements, backend):
expected = Dataset(
Expand Down Expand Up @@ -6153,17 +6148,18 @@ def test_constructor_raises_with_invalid_coords(unaligned_coords):
xr.DataArray([1, 2, 3], dims=["x"], coords=unaligned_coords)


def test_dir_expected_attrs(data_set):
@pytest.mark.parametrize("ds", [3], indirect=True)
def test_dir_expected_attrs(ds):

some_expected_attrs = {"pipe", "mean", "isnull", "var1", "dim2", "numbers"}
result = dir(data_set)
result = dir(ds)
assert set(result) >= some_expected_attrs


def test_dir_non_string(data_set):
def test_dir_non_string(ds):
# add a numbered key to ensure this doesn't break dir
data_set[5] = "foo"
result = dir(data_set)
ds[5] = "foo"
result = dir(ds)
assert 5 not in result

# GH2172
Expand All @@ -6173,16 +6169,16 @@ def test_dir_non_string(data_set):
dir(x2)


def test_dir_unicode(data_set):
data_set["unicode"] = "uni"
result = dir(data_set)
def test_dir_unicode(ds):
ds["unicode"] = "uni"
result = dir(ds)
assert "unicode" in result


@pytest.fixture(params=[1])
def ds(request):
def ds(request, backend):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixture is less elegant than I'd hoped; it now has three variations.

I think we could attempt to define one (maybe a couple) canonical datasets, and aim to use those any new tests that aren't testing something specific (e.g. some will need all missing data, but most tests have similar requirements).

Potentially that's just the result of the create_test_data function — is that our canonical dataset?

if request.param == 1:
return Dataset(
ds = Dataset(
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
Expand All @@ -6194,21 +6190,29 @@ def ds(request):
y=range(2),
),
)

if request.param == 2:
return Dataset(
{
"z1": (["time", "y"], np.random.randn(10, 2)),
"z2": (["time"], np.random.randn(10)),
"z3": (["x", "time"], np.random.randn(8, 10)),
},
{
"x": ("x", np.linspace(0, 1.0, 8)),
"time": ("time", np.linspace(0, 1.0, 10)),
"c": ("y", ["a", "b"]),
"y": range(2),
},
elif request.param == 2:
ds = Dataset(
dict(
z1=(["time", "y"], np.random.randn(10, 2)),
z2=(["time"], np.random.randn(10)),
z3=(["x", "time"], np.random.randn(8, 10)),
),
dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
y=range(2),
),
)
elif request.param == 3:
ds = create_test_data()
else:
raise ValueError

if backend == "dask":
return ds.chunk()

return ds


def test_coarsen_absent_dims_error(ds):
Expand Down Expand Up @@ -6526,6 +6530,7 @@ def test_rolling_properties(ds):
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("key", ("z1", "z2"))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key):
bn = pytest.importorskip("bottleneck", minversion="1.1")

Expand All @@ -6551,13 +6556,15 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key):


@requires_numbagg
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp(ds):

result = ds.rolling_exp(time=10, window_type="span").mean()
assert isinstance(result, Dataset)


@requires_numbagg
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
def test_rolling_exp_keep_attrs(ds):

attrs_global = {"attrs": "global"}
Expand Down