Skip to content

Commit

Permalink
cleaning up new test
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Sep 2, 2024
1 parent e2e34d5 commit 9ec6fb9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
20 changes: 20 additions & 0 deletions xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ def d(request, backend, type) -> DataArray | Dataset:
raise ValueError


@pytest.fixture
def byte_attrs_dataset():
"""For testing issue #9407"""
null_byte = b"\x00"
other_bytes = bytes(range(1, 256))
ds = Dataset({"x": 1}, coords={"x_coord": [1]})
ds["x"].attrs["null_byte"] = null_byte
ds["x"].attrs["other_bytes"] = other_bytes

expected = ds.copy()
expected["x"].attrs["null_byte"] = ""
expected["x"].attrs["other_bytes"] = other_bytes.decode(errors="replace")

return {
"input": ds,
"expected": expected,
"h5netcdf_error": r"Invalid value provided for attribute .*: .*\. Null characters .*",
}


@pytest.fixture(scope="module")
def create_test_datatree():
"""
Expand Down
27 changes: 10 additions & 17 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,23 +1404,12 @@ def test_refresh_from_disk(self) -> None:
a.close()
b.close()

def test_byte_attrs(self) -> None:
with create_tmp_file() as tmp_file:
try:
null_byte = b"\x00"
other_bytes = bytes(range(1, 256))
ds = Dataset({"x": 1}, coords={"x_coord": [1]})
ds["x"].attrs["null_byte"] = null_byte
ds["x"].attrs["other_bytes"] = other_bytes
self.save(ds, tmp_file)
except ValueError:
assert self.engine == "h5netcdf"
else:
with self.open(tmp_file) as ds_out:
assert ds_out["x"].attrs["null_byte"] == ""
assert ds_out["x"].attrs["other_bytes"] == other_bytes.decode(
errors="replace"
)
def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
# test for issue #9407
input = byte_attrs_dataset["input"]
expected = byte_attrs_dataset["expected"]
with self.roundtrip(input) as actual:
assert_identical(actual, expected)


_counter = itertools.count()
Expand Down Expand Up @@ -3879,6 +3868,10 @@ def test_decode_utf8_warning(self) -> None:
assert ds.title == title
assert "attribute 'title' of h5netcdf object '/'" in str(w[0].message)

def test_byte_attrs(self, byte_attrs_dataset: dict[str, Any]) -> None:
with pytest.raises(ValueError, match=byte_attrs_dataset["h5netcdf_error"]):
super().test_byte_attrs(byte_attrs_dataset)


@requires_h5netcdf
@requires_netCDF4
Expand Down

0 comments on commit 9ec6fb9

Please sign in to comment.