Skip to content

Commit aeaa082

Browse files
open_groups for zarr backends (#9469)
* open groups zarr initial commit
* added tests
* Added requested changes
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* TypeHint for zarr groups
* update for parent

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 637f820 commit aeaa082

File tree

2 files changed

+225
-136
lines changed

2 files changed

+225
-136
lines changed

xarray/backends/zarr.py

Lines changed: 72 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from xarray.backends.store import StoreBackendEntrypoint
2222
from xarray.core import indexing
23+
from xarray.core.treenode import NodePath
2324
from xarray.core.types import ZarrWriteModes
2425
from xarray.core.utils import (
2526
FrozenDict,
@@ -33,6 +34,8 @@
3334
if TYPE_CHECKING:
3435
from io import BufferedIOBase
3536

37+
from zarr import Group as ZarrGroup
38+
3639
from xarray.backends.common import AbstractDataStore
3740
from xarray.core.dataset import Dataset
3841
from xarray.core.datatree import DataTree
@@ -1218,66 +1221,86 @@ def open_datatree(
12181221
zarr_version=None,
12191222
**kwargs,
12201223
) -> DataTree:
1221-
from xarray.backends.api import open_dataset
12221224
from xarray.core.datatree import DataTree
1225+
1226+
filename_or_obj = _normalize_path(filename_or_obj)
1227+
groups_dict = self.open_groups_as_dict(filename_or_obj, **kwargs)
1228+
1229+
return DataTree.from_dict(groups_dict)
1230+
1231+
def open_groups_as_dict(
    self,
    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
    *,
    mask_and_scale=True,
    decode_times=True,
    concat_characters=True,
    decode_coords=True,
    drop_variables: str | Iterable[str] | None = None,
    use_cftime=None,
    decode_timedelta=None,
    group: str | Iterable[str] | Callable | None = None,
    mode="r",
    synchronizer=None,
    consolidated=None,
    chunk_store=None,
    storage_options=None,
    stacklevel=3,
    zarr_version=None,
    **kwargs,
) -> dict[str, Dataset]:
    """Open every group of a zarr store as a separate dataset.

    Parameters
    ----------
    filename_or_obj
        Location of the zarr store; normalised with ``_normalize_path``
        before being handed to ``ZarrStore.open_store``.
    mask_and_scale, decode_times, concat_characters, decode_coords, \
drop_variables, use_cftime, decode_timedelta
        Decoding options forwarded unchanged to
        ``StoreBackendEntrypoint.open_dataset`` for every group.
    group
        Optional group to start from. When truthy it is resolved against
        the root path ``"/"``; otherwise the root itself is used.
    mode, synchronizer, consolidated, chunk_store, storage_options, \
zarr_version
        Forwarded unchanged to ``ZarrStore.open_store``.
    stacklevel
        Forwarded to ``ZarrStore.open_store`` incremented by one.
    **kwargs
        Accepted for signature compatibility; not used by this method.

    Returns
    -------
    dict of str to Dataset
        One entry per item returned by ``ZarrStore.open_store``, keyed by
        the group path rendered through ``NodePath``.
    """
    from xarray.core.treenode import NodePath

    filename_or_obj = _normalize_path(filename_or_obj)

    # Check for a group and make it a parent if it exists
    root_path = NodePath("/")
    parent = str(root_path / NodePath(group)) if group else str(root_path)

    stores = ZarrStore.open_store(
        filename_or_obj,
        group=parent,
        mode=mode,
        synchronizer=synchronizer,
        consolidated=consolidated,
        consolidate_on_close=False,
        chunk_store=chunk_store,
        storage_options=storage_options,
        stacklevel=stacklevel + 1,
        zarr_version=zarr_version,
    )

    # The entrypoint does not depend on the group being decoded, so build
    # it once instead of once per loop iteration.
    store_entrypoint = StoreBackendEntrypoint()

    groups_dict = {}
    for path_group, store in stores.items():
        # Close the store if decoding this group raises.
        with close_on_error(store):
            group_ds = store_entrypoint.open_dataset(
                store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
        groups_dict[str(NodePath(path_group))] = group_ds

    return groups_dict
12731296

12741297

1275-
def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
    """Yield the path of ``root`` followed by the paths of all nested groups.

    Parameters
    ----------
    root
        Group whose ``groups()`` method yields ``(name, subgroup)`` pairs.
    parent
        Path under which ``root`` is reported. Defaults to ``"/"``.

    Yields
    ------
    str
        ``parent`` first, then the path of each descendant group exactly
        once, visiting a group's descendants before its next sibling.
    """
    parent_nodepath = NodePath(parent)
    yield str(parent_nodepath)
    for path, group in root.groups():
        # The recursive call yields the child's own path as its first item,
        # so yielding it here as well would report every subgroup twice.
        # Pass a plain string to honour the ``parent: str`` annotation.
        yield from _iter_zarr_groups(group, parent=str(parent_nodepath / path))
12831306

0 commit comments

Comments
 (0)