
Fix writing of DataTree subgroups to zarr or netCDF #9677

Merged: 9 commits, Nov 4, 2024
doc/whats-new.rst (8 additions, 1 deletion)
@@ -23,6 +23,9 @@ New Features
~~~~~~~~~~~~
- Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`).
By `Sam Levang <https://github.com/slevang>`_.
- Added ``write_inherited_coords`` option to :py:meth:`DataTree.to_netcdf`
and :py:meth:`DataTree.to_zarr` (:pull:`9677`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.
@@ -42,7 +45,11 @@ Deprecations
Bug fixes
~~~~~~~~~

- Fix inadvertent deep-copying of child data in DataTree.
- Fix inadvertent deep-copying of child data in DataTree (:issue:`9683`,
:pull:`9684`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Avoid including parent groups when writing DataTree subgroups to Zarr or
netCDF (:pull:`9677`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates (:pull:`9691`).
By `Pascal Bourgault <https://github.com/aulemahal>`_.
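For orientation, a minimal usage sketch of the ``write_inherited_coords`` option described in the entries above (the tree layout mirrors the tests in this PR; the file paths are hypothetical):

import xarray as xr
from xarray import DataTree

tree = DataTree.from_dict(
    {
        "/": xr.Dataset(coords={"x": [1, 2, 3]}),
        "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
    }
)

# Default: each coordinate is written only in the group where it is defined,
# so "x" is stored once, in the root group.
tree.to_zarr("compact.zarr")

# Opt in to replicating inherited coordinates on every descendant group,
# trading disk space for groups that are self-describing when opened alone.
tree.to_zarr("replicated.zarr", write_inherited_coords=True)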
xarray/core/datatree.py (14 additions, 0 deletions)
@@ -1573,6 +1573,7 @@ def to_netcdf(
format: T_DataTreeNetcdfTypes | None = None,
engine: T_DataTreeNetcdfEngine | None = None,
group: str | None = None,
write_inherited_coords: bool = False,
compute: bool = True,
**kwargs,
):
@@ -1609,6 +1610,11 @@
group : str, optional
Path to the netCDF4 group in the given file to open as the root group
of the ``DataTree``. Currently, specifying a group is not supported.
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later.
@@ -1632,6 +1638,7 @@
format=format,
engine=engine,
group=group,
write_inherited_coords=write_inherited_coords,
compute=compute,
**kwargs,
)
@@ -1643,6 +1650,7 @@ def to_zarr(
encoding=None,
consolidated: bool = True,
group: str | None = None,
write_inherited_coords: bool = False,
compute: Literal[True] = True,
**kwargs,
):
@@ -1668,6 +1676,11 @@
after writing metadata for all groups.
group : str, optional
Group path. (a.k.a. `path` in zarr terminology.)
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
compute : bool, default: True
If true compute immediately, otherwise return a
``dask.delayed.Delayed`` object that can be computed later. Metadata
@@ -1690,6 +1703,7 @@
encoding=encoding,
consolidated=consolidated,
group=group,
write_inherited_coords=write_inherited_coords,
compute=compute,
**kwargs,
)
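To illustrate the behaviour change these signatures enable, here is a sketch under the same assumptions as the tests added below (the file name is hypothetical): writing a non-root node now produces a file rooted at that node, without creating its parent groups.

import xarray as xr
from xarray import DataTree

tree = DataTree.from_dict(
    {
        "/": xr.Dataset(coords={"x": [1, 2, 3]}),
        "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
    }
)

# Writing a subgroup now writes only that node and its descendants; the
# parent group is no longer created in the target file.
tree.children["child"].to_netcdf("child_only.nc")

# The node being written keeps the coordinates it inherits from its parents,
# so the file round-trips on its own.
roundtrip = xr.open_datatree("child_only.nc")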
xarray/core/datatree_io.py (28 additions, 78 deletions)
@@ -2,54 +2,15 @@

from collections.abc import Mapping, MutableMapping
from os import PathLike
from typing import TYPE_CHECKING, Any, Literal, get_args
from typing import Any, Literal, get_args

from xarray.core.datatree import DataTree
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes

if TYPE_CHECKING:
from h5netcdf.legacyapi import Dataset as h5Dataset
from netCDF4 import Dataset as ncDataset

T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
T_DataTreeNetcdfTypes = Literal["NETCDF4"]


def _get_nc_dataset_class(
engine: T_DataTreeNetcdfEngine | None,
) -> type[ncDataset] | type[h5Dataset]:
if engine == "netcdf4":
from netCDF4 import Dataset as ncDataset

return ncDataset
if engine == "h5netcdf":
from h5netcdf.legacyapi import Dataset as h5Dataset

return h5Dataset
if engine is None:
try:
from netCDF4 import Dataset as ncDataset

return ncDataset
except ImportError:
from h5netcdf.legacyapi import Dataset as h5Dataset

return h5Dataset
raise ValueError(f"unsupported engine: {engine}")


def _create_empty_netcdf_group(
filename: str | PathLike,
group: str,
mode: NetcdfWriteModes,
engine: T_DataTreeNetcdfEngine | None,
) -> None:
ncDataset = _get_nc_dataset_class(engine)

with ncDataset(filename, mode=mode) as rootgrp:
rootgrp.createGroup(group)


def _datatree_to_netcdf(
dt: DataTree,
filepath: str | PathLike,
Expand All @@ -59,6 +20,7 @@ def _datatree_to_netcdf(
format: T_DataTreeNetcdfTypes | None = None,
engine: T_DataTreeNetcdfEngine | None = None,
group: str | None = None,
write_inherited_coords: bool = False,
compute: bool = True,
**kwargs,
) -> None:
@@ -97,41 +59,31 @@ def _datatree_to_netcdf(
unlimited_dims = {}

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
else:
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_netcdf(
filepath,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
unlimited_dims=unlimited_dims.get(node.path),
engine=engine,
format=format,
compute=compute,
**kwargs,
)
mode = "a"


def _create_empty_zarr_group(
store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
):
import zarr

root = zarr.open_group(store, mode=mode)
root.create_group(group, overwrite=True)


def _datatree_to_zarr(
dt: DataTree,
store: MutableMapping | str | PathLike[str],
mode: ZarrWriteModes = "w-",
encoding: Mapping[str, Any] | None = None,
consolidated: bool = True,
group: str | None = None,
write_inherited_coords: bool = False,
compute: Literal[True] = True,
**kwargs,
):
@@ -163,19 +115,17 @@
)

for node in dt.subtree:
ds = node.to_dataset(inherit=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
else:
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
at_root = node is dt
ds = node.to_dataset(inherit=write_inherited_coords or at_root)
group_path = None if at_root else "/" + node.relative_to(dt)
ds.to_zarr(
store,
group=group_path,
mode=mode,
encoding=encoding.get(node.path),
consolidated=False,
**kwargs,
)
if "w" in mode:
mode = "a"

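The zarr and netCDF writers above share the same traversal logic; the following condensed sketch restates it in one place (the helper name ``_iter_write_plan`` is hypothetical, for illustration only):

from xarray.core.datatree import DataTree


def _iter_write_plan(dt: DataTree, write_inherited_coords: bool = False):
    """Yield (group_path, dataset) pairs in the order the writers store them."""
    for node in dt.subtree:
        at_root = node is dt
        # The node being written always materializes its inherited coordinates;
        # descendants only do so when write_inherited_coords=True.
        ds = node.to_dataset(inherit=write_inherited_coords or at_root)
        # Group paths are taken relative to the written node, so writing a
        # subtree never creates its parent groups in the store.
        group_path = None if at_root else "/" + node.relative_to(dt)
        yield group_path, ds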
xarray/tests/test_backends_datatree.py (74 additions, 0 deletions)
@@ -196,6 +196,24 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
with pytest.raises(ValueError, match="unexpected encoding group.*"):
original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.nc"
original_dt.to_netcdf(filepath, engine=self.engine)

with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)


@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
@@ -556,3 +574,59 @@ def test_open_groups_chunks(self, tmpdir) -> None:

for ds in dict_of_datasets.values():
ds.close()

def test_write_subgroup(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
).children["child"]

expected_dt = original_dt.copy()
expected_dt.name = None

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_equal(original_dt, roundtrip_dt)
assert_identical(expected_dt, roundtrip_dt)

def test_write_inherited_coords_false(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
)

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath, write_inherited_coords=False)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_identical(original_dt, roundtrip_dt)

expected_child = original_dt.children["child"].copy(inherit=False)
expected_child.name = None
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
assert_identical(expected_child, roundtrip_child)

def test_write_inherited_coords_true(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset(coords={"x": [1, 2, 3]}),
"/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
}
)

filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath, write_inherited_coords=True)

with open_datatree(filepath, engine="zarr") as roundtrip_dt:
assert_identical(original_dt, roundtrip_dt)

expected_child = original_dt.children["child"].copy(inherit=True)
expected_child.name = None
with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
assert_identical(expected_child, roundtrip_child)