Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion datatree/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,22 @@ def __init__(
attrs=ds._attrs,
encoding=ds._encoding,
)
self._close = ds._close
self._close = None if data is None else data._close

def close(self) -> None:
"""Release any resources linked to this object."""
if self._close is not None:
self._close()
self._close = None

for child in self._children.values():
child.close()

def __enter__(self: DataTree) -> DataTree:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

@property
def name(self) -> str | None:
Expand Down
32 changes: 15 additions & 17 deletions datatree/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
original_dt = simple_datatree
original_dt.to_netcdf(filepath, engine="netcdf4")

roundtrip_dt = open_datatree(filepath)
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

@requires_netCDF4
def test_netcdf_encoding(self, tmpdir, simple_datatree):
Expand All @@ -29,10 +29,9 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}

original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4")
roundtrip_dt = open_datatree(filepath)

assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]
with open_datatree(filepath) as roundtrip_dt:
assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]

enc["/not/a/group"] = {"foo": "bar"}
with pytest.raises(ValueError, match="unexpected encoding group.*"):
Expand All @@ -46,8 +45,8 @@ def test_to_h5netcdf(self, tmpdir, simple_datatree):
original_dt = simple_datatree
original_dt.to_netcdf(filepath, engine="h5netcdf")

roundtrip_dt = open_datatree(filepath)
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

@requires_zarr
def test_to_zarr(self, tmpdir, simple_datatree):
Expand All @@ -57,8 +56,8 @@ def test_to_zarr(self, tmpdir, simple_datatree):
original_dt = simple_datatree
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

@requires_zarr
def test_zarr_encoding(self, tmpdir, simple_datatree):
Expand All @@ -72,10 +71,9 @@ def test_zarr_encoding(self, tmpdir, simple_datatree):
comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
original_dt.to_zarr(filepath, encoding=enc)
roundtrip_dt = open_datatree(filepath, engine="zarr")

print(roundtrip_dt["/set2/a"].encoding)
assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]

enc["/not/a/group"] = {"foo": "bar"}
with pytest.raises(ValueError, match="unexpected encoding group.*"):
Expand All @@ -92,8 +90,8 @@ def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
store = ZipStore(filepath)
original_dt.to_zarr(store)

roundtrip_dt = open_datatree(store, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(store, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)

@requires_zarr
def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
Expand All @@ -107,5 +105,5 @@ def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
assert not s1zmetadata.exists()

with pytest.warns(RuntimeWarning, match="consolidated"):
roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)