Skip to content

Commit 395604d

Browse files
zarr.open should fall back to opening a group (#2310)
* zarr.open should fall back to opening a group Closes #2309 * fixup * robuster * fixup test * Consistent version error * more
1 parent 81a87d6 commit 395604d

File tree

8 files changed

+84
-6
lines changed

8 files changed

+84
-6
lines changed

src/zarr/api/asynchronous.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from zarr.core.group import AsyncGroup
1515
from zarr.core.metadata.v2 import ArrayV2Metadata
1616
from zarr.core.metadata.v3 import ArrayV3Metadata
17+
from zarr.errors import NodeTypeValidationError
1718
from zarr.storage import (
1819
StoreLike,
1920
StorePath,
@@ -247,7 +248,10 @@ async def open(
247248

248249
try:
249250
return await open_array(store=store_path, zarr_format=zarr_format, **kwargs)
250-
except KeyError:
251+
except (KeyError, NodeTypeValidationError):
252+
# KeyError for a missing key
253+
# NodeTypeValidationError for failing to parse node metadata as an array when it's
254+
# actually a group
251255
return await open_group(store=store_path, zarr_format=zarr_format, **kwargs)
252256

253257

@@ -580,6 +584,8 @@ async def open_group(
580584
meta_array : array-like, optional
581585
An array instance to use for determining arrays to create and return
582586
to users. Use `numpy.empty(())` by default.
587+
attributes : dict
588+
A dictionary of JSON-serializable values with user-defined attributes.
583589
584590
Returns
585591
-------

src/zarr/api/synchronous.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def open_group(
207207
zarr_version: ZarrFormat | None = None, # deprecated
208208
zarr_format: ZarrFormat | None = None,
209209
meta_array: Any | None = None, # not used in async api
210+
attributes: dict[str, JSON] | None = None,
210211
) -> Group:
211212
return Group(
212213
sync(
@@ -221,6 +222,7 @@ def open_group(
221222
zarr_version=zarr_version,
222223
zarr_format=zarr_format,
223224
meta_array=meta_array,
225+
attributes=attributes,
224226
)
225227
)
226228
)

src/zarr/core/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from zarr.core.metadata.v2 import ArrayV2Metadata
6868
from zarr.core.metadata.v3 import ArrayV3Metadata
6969
from zarr.core.sync import collect_aiterator, sync
70+
from zarr.errors import MetadataValidationError
7071
from zarr.registry import get_pipeline_class
7172
from zarr.storage import StoreLike, make_store_path
7273
from zarr.storage.common import StorePath, ensure_no_existing_node
@@ -144,7 +145,7 @@ async def get_array_metadata(
144145
else:
145146
zarr_format = 2
146147
else:
147-
raise ValueError(f"unexpected zarr_format: {zarr_format}")
148+
raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format)
148149

149150
metadata_dict: dict[str, Any]
150151
if zarr_format == 2:

src/zarr/core/group.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from zarr.core.config import config
3131
from zarr.core.sync import SyncMixin, sync
32+
from zarr.errors import MetadataValidationError
3233
from zarr.storage import StoreLike, make_store_path
3334
from zarr.storage.common import StorePath, ensure_no_existing_node
3435

@@ -196,7 +197,7 @@ async def open(
196197
else:
197198
zarr_format = 2
198199
else:
199-
raise ValueError(f"unexpected zarr_format: {zarr_format}")
200+
raise MetadataValidationError("zarr_format", "2, 3, or None", zarr_format)
200201

201202
if zarr_format == 2:
202203
# V2 groups are comprised of a .zgroup and .zattrs objects

src/zarr/core/metadata/v3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from zarr.core.config import config
3030
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
3131
from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE
32+
from zarr.errors import MetadataValidationError, NodeTypeValidationError
3233
from zarr.registry import get_codec_class
3334

3435
DEFAULT_DTYPE = "float64"
@@ -37,13 +38,13 @@
3738
def parse_zarr_format(data: object) -> Literal[3]:
3839
if data == 3:
3940
return 3
40-
raise ValueError(f"Invalid value. Expected 3. Got {data}.")
41+
raise MetadataValidationError("zarr_format", 3, data)
4142

4243

4344
def parse_node_type_array(data: object) -> Literal["array"]:
4445
if data == "array":
4546
return "array"
46-
raise ValueError(f"Invalid value. Expected 'array'. Got {data}.")
47+
raise NodeTypeValidationError("node_type", "array", data)
4748

4849

4950
def parse_codecs(data: object) -> tuple[Codec, ...]:

src/zarr/errors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,21 @@ class ContainsArrayAndGroupError(_BaseZarrError):
2525
)
2626

2727

28+
class MetadataValidationError(_BaseZarrError):
29+
"""An exception raised when the Zarr metadata is invalid in some way"""
30+
31+
_msg = "Invalid value for '{}'. Expected '{}'. Got '{}'."
32+
33+
34+
class NodeTypeValidationError(MetadataValidationError):
35+
"""
36+
Specialized exception when the node_type of the metadata document is incorrect.
37+
38+
This can be raised when the value is invalid or unexpected given the context,
39+
for example an 'array' node when we expected a 'group'.
40+
"""
41+
42+
2843
__all__ = [
2944
"ContainsArrayAndGroupError",
3045
"ContainsArrayError",

tests/v3/test_api.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66
from numpy.testing import assert_array_equal
77

88
import zarr
9+
import zarr.api.asynchronous
10+
import zarr.core.group
911
from zarr import Array, Group
1012
from zarr.abc.store import Store
1113
from zarr.api.synchronous import create, group, load, open, open_group, save, save_array, save_group
1214
from zarr.core.common import ZarrFormat
15+
from zarr.errors import MetadataValidationError
1316
from zarr.storage.memory import MemoryStore
1417

1518

@@ -921,3 +924,37 @@ def test_open_group_positional_args_deprecated() -> None:
921924
store = MemoryStore({}, mode="w")
922925
with pytest.warns(FutureWarning, match="pass"):
923926
open_group(store, "w")
927+
928+
929+
def test_open_falls_back_to_open_group() -> None:
930+
# https://github.com/zarr-developers/zarr-python/issues/2309
931+
store = MemoryStore(mode="w")
932+
zarr.open_group(store, attributes={"key": "value"})
933+
934+
group = zarr.open(store)
935+
assert isinstance(group, Group)
936+
assert group.attrs == {"key": "value"}
937+
938+
939+
async def test_open_falls_back_to_open_group_async() -> None:
940+
# https://github.com/zarr-developers/zarr-python/issues/2309
941+
store = MemoryStore(mode="w")
942+
await zarr.api.asynchronous.open_group(store, attributes={"key": "value"})
943+
944+
group = await zarr.api.asynchronous.open(store=store)
945+
assert isinstance(group, zarr.core.group.AsyncGroup)
946+
assert group.attrs == {"key": "value"}
947+
948+
949+
async def test_metadata_validation_error() -> None:
950+
with pytest.raises(
951+
MetadataValidationError,
952+
match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.",
953+
):
954+
await zarr.api.asynchronous.open_group(zarr_format="3.0") # type: ignore[arg-type]
955+
956+
with pytest.raises(
957+
MetadataValidationError,
958+
match="Invalid value for 'zarr_format'. Expected '2, 3, or None'. Got '3.0'.",
959+
):
960+
await zarr.api.asynchronous.open_array(shape=(1,), zarr_format="3.0") # type: ignore[arg-type]

tests/v3/test_metadata/test_v3.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
default_fill_value,
2525
parse_dimension_names,
2626
parse_fill_value,
27+
parse_node_type_array,
2728
parse_zarr_format,
2829
)
2930

@@ -54,14 +55,28 @@
5455

5556
@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
5657
def test_parse_zarr_format_invalid(data: Any) -> None:
57-
with pytest.raises(ValueError, match=f"Invalid value. Expected 3. Got {data}"):
58+
with pytest.raises(
59+
ValueError, match=f"Invalid value for 'zarr_format'. Expected '3'. Got '{data}'."
60+
):
5861
parse_zarr_format(data)
5962

6063

6164
def test_parse_zarr_format_valid() -> None:
6265
assert parse_zarr_format(3) == 3
6366

6467

68+
@pytest.mark.parametrize("data", [None, "group"])
69+
def test_parse_node_type_arrayinvalid(data: Any) -> None:
70+
with pytest.raises(
71+
ValueError, match=f"Invalid value for 'node_type'. Expected 'array'. Got '{data}'."
72+
):
73+
parse_node_type_array(data)
74+
75+
76+
def test_parse_node_typevalid() -> None:
77+
assert parse_node_type_array("array") == "array"
78+
79+
6580
@pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}])
6681
def parse_dimension_names_invalid(data: Any) -> None:
6782
with pytest.raises(TypeError, match="Expected either None or iterable of str,"):

0 commit comments

Comments
 (0)