Skip to content

Commit 4d663cc

Browse files
Fixed consolidated Group getitem with multi-part key (#2363)
* Fixed consolidated Group getitem with multi-part key This fixes `Group.__getitem__` when indexing with a key like 'subgroup/array'. The basic idea is to rewrite the indexing operation as `group['subgroup']['array']` by splitting the key and doing each operation independently. Closes #2358 --------- Co-authored-by: Joe Hamman <joe@earthmover.io>
1 parent 3a7426f commit 4d663cc

File tree

2 files changed

+101
-39
lines changed

2 files changed

+101
-39
lines changed

src/zarr/core/group.py

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,10 @@ def _from_bytes_v2(
572572

573573
@classmethod
574574
def _from_bytes_v3(
575-
cls, store_path: StorePath, zarr_json_bytes: Buffer, use_consolidated: bool | None
575+
cls,
576+
store_path: StorePath,
577+
zarr_json_bytes: Buffer,
578+
use_consolidated: bool | None,
576579
) -> AsyncGroup:
577580
group_metadata = json.loads(zarr_json_bytes.to_bytes())
578581
if use_consolidated and group_metadata.get("consolidated_metadata") is None:
@@ -666,14 +669,33 @@ def _getitem_consolidated(
666669
# the caller needs to verify this!
667670
assert self.metadata.consolidated_metadata is not None
668671

669-
try:
670-
metadata = self.metadata.consolidated_metadata.metadata[key]
671-
except KeyError as e:
672-
# The Group Metadata has consolidated metadata, but the key
673-
# isn't present. We trust this to mean that the key isn't in
674-
# the hierarchy, and *don't* fall back to checking the store.
675-
msg = f"'{key}' not found in consolidated metadata."
676-
raise KeyError(msg) from e
672+
# we support nested getitems like group/subgroup/array
673+
indexers = key.split("/")
674+
indexers.reverse()
675+
metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata = self.metadata
676+
677+
while indexers:
678+
indexer = indexers.pop()
679+
if isinstance(metadata, ArrayV2Metadata | ArrayV3Metadata):
680+
# we've indexed into an array with group["array/subarray"]. Invalid.
681+
raise KeyError(key)
682+
if metadata.consolidated_metadata is None:
683+
# we've indexed into a group without consolidated metadata.
684+
# This isn't normal; typically, consolidated metadata
685+
# will include explicit markers for when there are no child
686+
# nodes as metadata={}.
687+
# We have some freedom in exactly how we interpret this case.
688+
# For now, we treat None as the same as {}, i.e. we don't
689+
# have any children.
690+
raise KeyError(key)
691+
try:
692+
metadata = metadata.consolidated_metadata.metadata[indexer]
693+
except KeyError as e:
694+
# The Group Metadata has consolidated metadata, but the key
695+
# isn't present. We trust this to mean that the key isn't in
696+
# the hierarchy, and *don't* fall back to checking the store.
697+
msg = f"'{key}' not found in consolidated metadata."
698+
raise KeyError(msg) from e
677699

678700
# update store_path to ensure that AsyncArray/Group.name is correct
679701
if prefix != "/":
@@ -932,11 +954,7 @@ async def create_array(
932954

933955
@deprecated("Use AsyncGroup.create_array instead.")
934956
async def create_dataset(
935-
self,
936-
name: str,
937-
*,
938-
shape: ShapeLike,
939-
**kwargs: Any,
957+
self, name: str, **kwargs: Any
940958
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
941959
"""Create an array.
942960
@@ -947,8 +965,6 @@ async def create_dataset(
947965
----------
948966
name : str
949967
Array name.
950-
shape : int or tuple of ints
951-
Array shape.
952968
kwargs : dict
953969
Additional arguments passed to :func:`zarr.AsyncGroup.create_array`.
954970
@@ -959,7 +975,7 @@ async def create_dataset(
959975
.. deprecated:: 3.0.0
960976
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.create_array` instead.
961977
"""
962-
return await self.create_array(name, shape=shape, **kwargs)
978+
return await self.create_array(name, **kwargs)
963979

964980
@deprecated("Use AsyncGroup.require_array instead.")
965981
async def require_dataset(
@@ -1081,6 +1097,8 @@ async def nmembers(
10811097
-------
10821098
count : int
10831099
"""
1100+
# check if we can use consolidated metadata, which requires that we have non-None
1101+
# consolidated metadata at all points in the hierarchy.
10841102
if self.metadata.consolidated_metadata is not None:
10851103
return len(self.metadata.consolidated_metadata.flattened_metadata)
10861104
# TODO: consider using aioitertools.builtins.sum for this
@@ -1094,7 +1112,8 @@ async def members(
10941112
self,
10951113
max_depth: int | None = 0,
10961114
) -> AsyncGenerator[
1097-
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
1115+
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
1116+
None,
10981117
]:
10991118
"""
11001119
Returns an AsyncGenerator over the arrays and groups contained in this group.
@@ -1125,12 +1144,12 @@ async def members(
11251144
async def _members(
11261145
self, max_depth: int | None, current_depth: int
11271146
) -> AsyncGenerator[
1128-
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
1147+
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
1148+
None,
11291149
]:
11301150
if self.metadata.consolidated_metadata is not None:
11311151
# we should be able to do members without any additional I/O
11321152
members = self._members_consolidated(max_depth, current_depth)
1133-
11341153
for member in members:
11351154
yield member
11361155
return
@@ -1186,7 +1205,8 @@ async def _members(
11861205
def _members_consolidated(
11871206
self, max_depth: int | None, current_depth: int, prefix: str = ""
11881207
) -> Generator[
1189-
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
1208+
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
1209+
None,
11901210
]:
11911211
consolidated_metadata = self.metadata.consolidated_metadata
11921212

@@ -1271,7 +1291,11 @@ async def full(
12711291
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
12721292
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
12731293
return await async_api.full(
1274-
shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs
1294+
shape=shape,
1295+
fill_value=fill_value,
1296+
store=self.store_path,
1297+
path=name,
1298+
**kwargs,
12751299
)
12761300

12771301
async def empty_like(
@@ -1627,13 +1651,7 @@ def create_dataset(self, name: str, **kwargs: Any) -> Array:
16271651
return Array(self._sync(self._async_group.create_dataset(name, **kwargs)))
16281652

16291653
@deprecated("Use Group.require_array instead.")
1630-
def require_dataset(
1631-
self,
1632-
name: str,
1633-
*,
1634-
shape: ShapeLike,
1635-
**kwargs: Any,
1636-
) -> Array:
1654+
def require_dataset(self, name: str, **kwargs: Any) -> Array:
16371655
"""Obtain an array, creating if it doesn't exist.
16381656
16391657
Arrays are known as "datasets" in HDF5 terminology. For compatibility
@@ -1660,15 +1678,9 @@ def require_dataset(
16601678
.. deprecated:: 3.0.0
16611679
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.
16621680
"""
1663-
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
1681+
return Array(self._sync(self._async_group.require_array(name, **kwargs)))
16641682

1665-
def require_array(
1666-
self,
1667-
name: str,
1668-
*,
1669-
shape: ShapeLike,
1670-
**kwargs: Any,
1671-
) -> Array:
1683+
def require_array(self, name: str, **kwargs: Any) -> Array:
16721684
"""Obtain an array, creating if it doesn't exist.
16731685
16741686
@@ -1690,7 +1702,7 @@ def require_array(
16901702
-------
16911703
a : Array
16921704
"""
1693-
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
1705+
return Array(self._sync(self._async_group.require_array(name, **kwargs)))
16941706

16951707
def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
16961708
return Array(self._sync(self._async_group.empty(name=name, shape=shape, **kwargs)))

tests/v3/test_group.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,18 +306,53 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool
306306
group = Group.from_store(store, zarr_format=zarr_format)
307307
subgroup = group.create_group(name="subgroup")
308308
subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
309+
subsubarray = subgroup.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
309310

310311
if consolidated:
311312
group = zarr.api.synchronous.consolidate_metadata(store=store, zarr_format=zarr_format)
313+
# we're going to assume that `group.metadata` is correct, and reuse that to focus
314+
# on indexing in this test. Other tests verify the correctness of group.metadata
312315
object.__setattr__(
313-
subgroup.metadata, "consolidated_metadata", ConsolidatedMetadata(metadata={})
316+
subgroup.metadata,
317+
"consolidated_metadata",
318+
ConsolidatedMetadata(
319+
metadata={"subarray": group.metadata.consolidated_metadata.metadata["subarray"]}
320+
),
314321
)
315322

316323
assert group["subgroup"] == subgroup
317324
assert group["subarray"] == subarray
325+
assert group["subgroup"]["subarray"] == subsubarray
326+
assert group["subgroup/subarray"] == subsubarray
327+
318328
with pytest.raises(KeyError):
319329
group["nope"]
320330

331+
with pytest.raises(KeyError, match="subarray/subsubarray"):
332+
group["subarray/subsubarray"]
333+
334+
# Now test the mixed case
335+
if consolidated:
336+
object.__setattr__(
337+
group.metadata.consolidated_metadata.metadata["subgroup"],
338+
"consolidated_metadata",
339+
None,
340+
)
341+
342+
# test the implementation directly
343+
with pytest.raises(KeyError):
344+
group._async_group._getitem_consolidated(
345+
group.store_path, "subgroup/subarray", prefix="/"
346+
)
347+
348+
with pytest.raises(KeyError):
349+
# We've chosen to trust the consolidted metadata, which doesn't
350+
# contain this array
351+
group["subgroup/subarray"]
352+
353+
with pytest.raises(KeyError, match="subarray/subsubarray"):
354+
group["subarray/subsubarray"]
355+
321356

322357
def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
323358
group = Group.from_store(store, zarr_format=zarr_format)
@@ -1008,6 +1043,21 @@ async def test_group_members_async(store: Store, consolidated_metadata: bool) ->
10081043
with pytest.raises(ValueError, match="max_depth"):
10091044
[x async for x in group.members(max_depth=-1)]
10101045

1046+
if consolidated_metadata:
1047+
# test for mixed known and unknown metadata.
1048+
# For now, we trust the consolidated metadata.
1049+
object.__setattr__(
1050+
group.metadata.consolidated_metadata.metadata["g0"].consolidated_metadata.metadata[
1051+
"g1"
1052+
],
1053+
"consolidated_metadata",
1054+
None,
1055+
)
1056+
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
1057+
assert len(all_children) == 4
1058+
nmembers = await group.nmembers(max_depth=None)
1059+
assert nmembers == 4
1060+
10111061

10121062
async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
10131063
root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format)

0 commit comments

Comments
 (0)