Skip to content

Commit 5a134bf

Browse files
authored
fix: add get(key, default) method to Group APIs (#2311)
* fix: add get(key, default) method to Group APIs * add test
1 parent 7e2be57 commit 5a134bf

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/zarr/core/group.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from dataclasses import asdict, dataclass, field, fields, replace
7-
from typing import TYPE_CHECKING, Literal, cast, overload
7+
from typing import TYPE_CHECKING, Literal, TypeVar, cast, overload
88

99
import numpy as np
1010
import numpy.typing as npt
@@ -42,6 +42,8 @@
4242

4343
logger = logging.getLogger("zarr.group")
4444

45+
DefaultT = TypeVar("DefaultT")
46+
4547

4648
def parse_zarr_format(data: Any) -> ZarrFormat:
4749
if data in (2, 3):
@@ -290,6 +292,28 @@ async def delitem(self, key: str) -> None:
290292
else:
291293
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
292294

295+
async def get(
296+
self, key: str, default: DefaultT | None = None
297+
) -> AsyncArray | AsyncGroup | DefaultT | None:
298+
"""Obtain a group member, returning default if not found.
299+
300+
Parameters
301+
----------
302+
key : string
303+
Group member name.
304+
default : object
305+
Default value to return if key is not found (default: None).
306+
307+
Returns
308+
-------
309+
object
310+
Group member (AsyncArray or AsyncGroup) or default if not found.
311+
"""
312+
try:
313+
return await self.getitem(key)
314+
except KeyError:
315+
return default
316+
293317
async def _save_metadata(self, ensure_parents: bool = False) -> None:
294318
to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
295319
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
@@ -828,6 +852,26 @@ def __getitem__(self, path: str) -> Array | Group:
828852
else:
829853
return Group(obj)
830854

855+
def get(self, path: str, default: DefaultT | None = None) -> Array | Group | DefaultT | None:
856+
"""Obtain a group member, returning default if not found.
857+
858+
Parameters
859+
----------
860+
key : string
861+
Group member name.
862+
default : object
863+
Default value to return if key is not found (default: None).
864+
865+
Returns
866+
-------
867+
object
868+
Group member (Array or Group) or default if not found.
869+
"""
870+
try:
871+
return self[path]
872+
except KeyError:
873+
return default
874+
831875
def __delitem__(self, key: str) -> None:
832876
self._sync(self._async_group.delitem(key))
833877

tests/v3/test_group.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,25 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat) -> None:
292292
group["nope"]
293293

294294

295+
def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
296+
group = Group.from_store(store, zarr_format=zarr_format)
297+
298+
# default behavior
299+
result = group.get("subgroup")
300+
assert result is None
301+
302+
# custom default
303+
result = group.get("subgroup", 8)
304+
assert result == 8
305+
306+
# now with a group
307+
subgroup = group.require_group("subgroup")
308+
subgroup.attrs["foo"] = "bar"
309+
310+
result = group.get("subgroup", 8)
311+
assert result.attrs["foo"] == "bar"
312+
313+
295314
def test_group_delitem(store: Store, zarr_format: ZarrFormat) -> None:
296315
"""
297316
Test the `Group.__delitem__` method.

0 commit comments

Comments
 (0)