Skip to content

Commit 3b793c1

Browse files
authored
[v3] h5py compat methods on Group (#2128)
* feature(h5compat): add create_dataset, require_dataset, require_group, and require_gruops methods to group class * make mypy happy * doc fixes * write initial tests * more tests * add deprecation warnings * add deprecation warnings * switch up test
1 parent 60b4f57 commit 3b793c1

File tree

3 files changed

+337
-2
lines changed

3 files changed

+337
-2
lines changed

src/zarr/core/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ZATTRS_JSON = ".zattrs"
2929

3030
BytesLike = bytes | bytearray | memoryview
31+
ShapeLike = tuple[int, ...] | int
3132
ChunkCoords = tuple[int, ...]
3233
ChunkCoordsLike = Iterable[int]
3334
ZarrFormat = Literal[2, 3]

src/zarr/core/group.py

Lines changed: 249 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from dataclasses import asdict, dataclass, field, replace
88
from typing import TYPE_CHECKING, Literal, cast, overload
99

10+
import numpy as np
1011
import numpy.typing as npt
1112
from typing_extensions import deprecated
1213

@@ -25,6 +26,7 @@
2526
ZGROUP_JSON,
2627
ChunkCoords,
2728
ZarrFormat,
29+
parse_shapelike,
2830
)
2931
from zarr.core.config import config
3032
from zarr.core.sync import SyncMixin, sync
@@ -250,7 +252,7 @@ async def getitem(
250252
if zarray is not None:
251253
# TODO: update this once the V2 array support is part of the primary array class
252254
zarr_json = {**zarray, "attributes": zattrs}
253-
return AsyncArray.from_dict(store_path, zarray)
255+
return AsyncArray.from_dict(store_path, zarr_json)
254256
else:
255257
zgroup = (
256258
json.loads(zgroup_bytes.to_bytes())
@@ -324,6 +326,42 @@ async def create_group(
324326
zarr_format=self.metadata.zarr_format,
325327
)
326328

329+
async def require_group(self, name: str, overwrite: bool = False) -> AsyncGroup:
330+
"""Obtain a sub-group, creating one if it doesn't exist.
331+
332+
Parameters
333+
----------
334+
name : string
335+
Group name.
336+
overwrite : bool, optional
337+
Overwrite any existing group with given `name` if present.
338+
339+
Returns
340+
-------
341+
g : AsyncGroup
342+
"""
343+
if overwrite:
344+
# TODO: check that exists_ok=True errors if an array exists where the group is being created
345+
grp = await self.create_group(name, exists_ok=True)
346+
else:
347+
try:
348+
item: AsyncGroup | AsyncArray = await self.getitem(name)
349+
if not isinstance(item, AsyncGroup):
350+
raise TypeError(
351+
f"Incompatible object ({item.__class__.__name__}) already exists"
352+
)
353+
assert isinstance(item, AsyncGroup) # make mypy happy
354+
grp = item
355+
except KeyError:
356+
grp = await self.create_group(name)
357+
return grp
358+
359+
async def require_groups(self, *names: str) -> tuple[AsyncGroup, ...]:
360+
"""Convenience method to require multiple groups in a single call."""
361+
if not names:
362+
return ()
363+
return tuple(await asyncio.gather(*(self.require_group(name) for name in names)))
364+
327365
async def create_array(
328366
self,
329367
name: str,
@@ -413,6 +451,117 @@ async def create_array(
413451
data=data,
414452
)
415453

454+
@deprecated("Use AsyncGroup.create_array instead.")
455+
async def create_dataset(self, name: str, **kwargs: Any) -> AsyncArray:
456+
"""Create an array.
457+
458+
Arrays are known as "datasets" in HDF5 terminology. For compatibility
459+
with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.require_dataset` method.
460+
461+
Parameters
462+
----------
463+
name : string
464+
Array name.
465+
kwargs : dict
466+
Additional arguments passed to :func:`zarr.AsyncGroup.create_array`.
467+
468+
Returns
469+
-------
470+
a : AsyncArray
471+
472+
.. deprecated:: 3.0.0
473+
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.create_array` instead.
474+
"""
475+
return await self.create_array(name, **kwargs)
476+
477+
@deprecated("Use AsyncGroup.require_array instead.")
478+
async def require_dataset(
479+
self,
480+
name: str,
481+
*,
482+
shape: ChunkCoords,
483+
dtype: npt.DTypeLike = None,
484+
exact: bool = False,
485+
**kwargs: Any,
486+
) -> AsyncArray:
487+
"""Obtain an array, creating if it doesn't exist.
488+
489+
Arrays are known as "datasets" in HDF5 terminology. For compatibility
490+
with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.create_dataset` method.
491+
492+
Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`.
493+
494+
Parameters
495+
----------
496+
name : string
497+
Array name.
498+
shape : int or tuple of ints
499+
Array shape.
500+
dtype : string or dtype, optional
501+
NumPy dtype.
502+
exact : bool, optional
503+
If True, require `dtype` to match exactly. If false, require
504+
`dtype` can be cast from array dtype.
505+
506+
Returns
507+
-------
508+
a : AsyncArray
509+
510+
.. deprecated:: 3.0.0
511+
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.require_dataset` instead.
512+
"""
513+
return await self.require_array(name, shape=shape, dtype=dtype, exact=exact, **kwargs)
514+
515+
async def require_array(
516+
self,
517+
name: str,
518+
*,
519+
shape: ChunkCoords,
520+
dtype: npt.DTypeLike = None,
521+
exact: bool = False,
522+
**kwargs: Any,
523+
) -> AsyncArray:
524+
"""Obtain an array, creating if it doesn't exist.
525+
526+
Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`.
527+
528+
Parameters
529+
----------
530+
name : string
531+
Array name.
532+
shape : int or tuple of ints
533+
Array shape.
534+
dtype : string or dtype, optional
535+
NumPy dtype.
536+
exact : bool, optional
537+
If True, require `dtype` to match exactly. If false, require
538+
`dtype` can be cast from array dtype.
539+
540+
Returns
541+
-------
542+
a : AsyncArray
543+
"""
544+
try:
545+
ds = await self.getitem(name)
546+
if not isinstance(ds, AsyncArray):
547+
raise TypeError(f"Incompatible object ({ds.__class__.__name__}) already exists")
548+
549+
shape = parse_shapelike(shape)
550+
if shape != ds.shape:
551+
raise TypeError(f"Incompatible shape ({ds.shape} vs {shape})")
552+
553+
dtype = np.dtype(dtype)
554+
if exact:
555+
if ds.dtype != dtype:
556+
raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})")
557+
else:
558+
if not np.can_cast(ds.dtype, dtype):
559+
raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})")
560+
except KeyError:
561+
ds = await self.create_array(name, shape=shape, dtype=dtype, **kwargs)
562+
563+
return ds
564+
416565
async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
417566
# metadata.attributes is "frozen" so we simply clear and update the dict
418567
self.metadata.attributes.clear()
@@ -612,8 +761,9 @@ def create(
612761
def open(
613762
cls,
614763
store: StoreLike,
764+
zarr_format: Literal[2, 3, None] = 3,
615765
) -> Group:
616-
obj = sync(AsyncGroup.open(store))
766+
obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
617767
return cls(obj)
618768

619769
def __getitem__(self, path: str) -> Array | Group:
@@ -717,6 +867,26 @@ def tree(self, expand: bool = False, level: int | None = None) -> Any:
717867
def create_group(self, name: str, **kwargs: Any) -> Group:
718868
return Group(self._sync(self._async_group.create_group(name, **kwargs)))
719869

870+
def require_group(self, name: str, **kwargs: Any) -> Group:
871+
"""Obtain a sub-group, creating one if it doesn't exist.
872+
873+
Parameters
874+
----------
875+
name : string
876+
Group name.
877+
overwrite : bool, optional
878+
Overwrite any existing group with given `name` if present.
879+
880+
Returns
881+
-------
882+
g : Group
883+
"""
884+
return Group(self._sync(self._async_group.require_group(name, **kwargs)))
885+
886+
def require_groups(self, *names: str) -> tuple[Group, ...]:
887+
"""Convenience method to require multiple groups in a single call."""
888+
return tuple(map(Group, self._sync(self._async_group.require_groups(*names))))
889+
720890
def create_array(
721891
self,
722892
name: str,
@@ -811,6 +981,83 @@ def create_array(
811981
)
812982
)
813983

984+
@deprecated("Use Group.create_array instead.")
985+
def create_dataset(self, name: str, **kwargs: Any) -> Array:
986+
"""Create an array.
987+
988+
Arrays are known as "datasets" in HDF5 terminology. For compatibility
989+
with h5py, Zarr groups also implement the :func:`zarr.Group.require_dataset` method.
990+
991+
Parameters
992+
----------
993+
name : string
994+
Array name.
995+
kwargs : dict
996+
Additional arguments passed to :func:`zarr.Group.create_array`
997+
998+
Returns
999+
-------
1000+
a : Array
1001+
1002+
.. deprecated:: 3.0.0
1003+
The h5py compatibility methods will be removed in 3.1.0. Use `Group.create_array` instead.
1004+
"""
1005+
return Array(self._sync(self._async_group.create_dataset(name, **kwargs)))
1006+
1007+
@deprecated("Use Group.require_array instead.")
1008+
def require_dataset(self, name: str, **kwargs: Any) -> Array:
1009+
"""Obtain an array, creating if it doesn't exist.
1010+
1011+
Arrays are known as "datasets" in HDF5 terminology. For compatibility
1012+
with h5py, Zarr groups also implement the :func:`zarr.Group.create_dataset` method.
1013+
1014+
Other `kwargs` are as per :func:`zarr.Group.create_dataset`.
1015+
1016+
Parameters
1017+
----------
1018+
name : string
1019+
Array name.
1020+
shape : int or tuple of ints
1021+
Array shape.
1022+
dtype : string or dtype, optional
1023+
NumPy dtype.
1024+
exact : bool, optional
1025+
If True, require `dtype` to match exactly. If false, require
1026+
`dtype` can be cast from array dtype.
1027+
1028+
Returns
1029+
-------
1030+
a : Array
1031+
1032+
.. deprecated:: 3.0.0
1033+
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.
1034+
"""
1035+
return Array(self._sync(self._async_group.require_array(name, **kwargs)))
1036+
1037+
def require_array(self, name: str, **kwargs: Any) -> Array:
1038+
"""Obtain an array, creating if it doesn't exist.
1039+
1040+
1041+
Other `kwargs` are as per :func:`zarr.Group.create_array`.
1042+
1043+
Parameters
1044+
----------
1045+
name : string
1046+
Array name.
1047+
shape : int or tuple of ints
1048+
Array shape.
1049+
dtype : string or dtype, optional
1050+
NumPy dtype.
1051+
exact : bool, optional
1052+
If True, require `dtype` to match exactly. If false, require
1053+
`dtype` can be cast from array dtype.
1054+
1055+
Returns
1056+
-------
1057+
a : Array
1058+
"""
1059+
return Array(self._sync(self._async_group.require_array(name, **kwargs)))
1060+
8141061
def empty(self, **kwargs: Any) -> Array:
8151062
return Array(self._sync(self._async_group.empty(**kwargs)))
8161063

0 commit comments

Comments
 (0)