Skip to content

Commit 251253e

Browse files
committed
Round trip serialization for array metadata v2/v3
1 parent 370eb8b commit 251253e

File tree

5 files changed

+29
-170
lines changed

5 files changed

+29
-170
lines changed

changes/2802.fix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix `fill_value` serialization for `NaN` in `ArrayV2Metadata` and add property-based testing of round-trip serialization

src/zarr/core/metadata/v2.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
170170
if dtype.kind in "SV":
171171
fill_value_encoded = _data.get("fill_value")
172172
if fill_value_encoded is not None:
173-
fill_value = base64.standard_b64decode(fill_value_encoded)
173+
fill_value: Any = base64.standard_b64decode(fill_value_encoded)
174174
_data["fill_value"] = fill_value
175175
else:
176176
fill_value = _data.get("fill_value")
@@ -180,13 +180,11 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
180180
_data["fill_value"] = np.array("NaT", dtype=dtype)[()]
181181
else:
182182
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183-
elif dtype.kind == "c" and isinstance(fill_value, list):
184-
if len(fill_value) == 2:
185-
val = complex(float(fill_value[0]), float(fill_value[1]))
186-
_data["fill_value"] = np.array(val, dtype=dtype)[()]
187-
elif dtype.kind in "f" and isinstance(fill_value, str):
188-
if fill_value in {"NaN", "Infinity", "-Infinity"}:
189-
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
183+
elif dtype.kind == "c" and isinstance(fill_value, list) and len(fill_value) == 2:
184+
val = complex(float(fill_value[0]), float(fill_value[1]))
185+
_data["fill_value"] = np.array(val, dtype=dtype)[()]
186+
elif dtype.kind in "f" and fill_value in {"NaN", "Infinity", "-Infinity"}:
187+
_data["fill_value"] = np.array(fill_value, dtype=dtype)[()]
190188
# zarr v2 allowed arbitrary keys in the metadata.
191189
# Filter the keys to only those expected by the constructor.
192190
expected = {x.name for x in fields(cls)}
@@ -196,21 +194,22 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
196194
return cls(**_data)
197195

198196
def to_dict(self) -> dict[str, JSON]:
199-
def _sanitize_fill_value(fv: Any):
197+
def _sanitize_fill_value(fv: Any) -> JSON:
200198
if fv is None:
201199
return fv
202200
elif isinstance(fv, np.datetime64):
203201
if np.isnat(fv):
204202
return "NaT"
205203
return np.datetime_as_string(fv)
206204
elif isinstance(fv, numbers.Real):
207-
if np.isnan(fv):
205+
float_fv = float(fv)
206+
if np.isnan(float_fv):
208207
fv = "NaN"
209-
elif np.isinf(fv):
210-
fv = "Infinity" if fv > 0 else "-Infinity"
208+
elif np.isinf(float_fv):
209+
fv = "Infinity" if float_fv > 0 else "-Infinity"
211210
elif isinstance(fv, numbers.Complex):
212211
fv = [_sanitize_fill_value(fv.real), _sanitize_fill_value(fv.imag)]
213-
return fv
212+
return cast(JSON, fv)
214213

215214
zarray_dict = super().to_dict()
216215

src/zarr/testing/stateful.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_group(self, name: str, data: DataObject) -> None:
8585
@rule(
8686
data=st.data(),
8787
name=node_names,
88-
array_and_chunks=np_array_and_chunks(nparrays=numpy_arrays(zarr_formats=st.just(3))),
88+
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
8989
)
9090
def add_array(
9191
self,

src/zarr/testing/strategies.py

Lines changed: 0 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import hypothesis.extra.numpy as npst
55
import hypothesis.strategies as st
6-
import numcodecs
76
import numpy as np
87
from hypothesis import assume, given, settings # noqa: F401
98
from hypothesis.strategies import SearchStrategy
@@ -345,136 +344,3 @@ def make_request(start: int, length: int) -> RangeByteRequest:
345344
)
346345
key_tuple = st.tuples(keys, byte_ranges)
347346
return st.lists(key_tuple, min_size=1, max_size=10)
348-
349-
350-
def simple_text():
351-
"""A strategy for generating simple text strings."""
352-
return st.text(st.characters(min_codepoint=32, max_codepoint=126), min_size=1, max_size=10)
353-
354-
355-
def simple_attrs():
356-
"""A strategy for generating simple attribute dictionaries."""
357-
return st.dictionaries(
358-
simple_text(),
359-
st.one_of(
360-
st.integers(),
361-
st.floats(allow_nan=False, allow_infinity=False),
362-
st.booleans(),
363-
simple_text(),
364-
),
365-
)
366-
367-
368-
def array_shapes(min_dims=1, max_dims=3, max_len=100):
369-
"""A strategy for generating array shapes."""
370-
return st.lists(
371-
st.integers(min_value=1, max_value=max_len), min_size=min_dims, max_size=max_dims
372-
)
373-
374-
375-
# def zarr_compressors():
376-
# """A strategy for generating Zarr compressors."""
377-
# return st.sampled_from([None, Blosc(), GZip(), Zstd(), LZ4()])
378-
379-
380-
# def zarr_codecs():
381-
# """A strategy for generating Zarr codecs."""
382-
# return st.sampled_from([BytesCodec(), Blosc(), GZip(), Zstd(), LZ4()])
383-
384-
385-
def zarr_filters():
386-
"""A strategy for generating Zarr filters."""
387-
return st.lists(
388-
st.just(numcodecs.Delta(dtype="i4")), min_size=0, max_size=2
389-
) # Example filter, expand as needed
390-
391-
392-
def zarr_storage_transformers():
393-
"""A strategy for generating Zarr storage transformers."""
394-
return st.lists(
395-
st.dictionaries(
396-
simple_text(), st.one_of(st.integers(), st.floats(), st.booleans(), simple_text())
397-
),
398-
min_size=0,
399-
max_size=2,
400-
)
401-
402-
403-
@st.composite
404-
def array_metadata_v2(draw: st.DrawFn) -> ArrayV2Metadata:
405-
"""Generates valid ArrayV2Metadata objects for property-based testing."""
406-
dims = draw(st.integers(min_value=1, max_value=3)) # Limit dimensions for complexity
407-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
408-
max_chunk_len = max(shape) if shape else 100
409-
chunks = tuple(
410-
draw(
411-
st.lists(
412-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
413-
)
414-
)
415-
)
416-
417-
# Validate shape and chunks relationship
418-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False))) # Chunk size must be <= shape
419-
420-
dtype = draw(v2_dtypes())
421-
fill_value = draw(st.one_of([st.none(), npst.from_dtype(dtype)]))
422-
order = draw(st.sampled_from(["C", "F"]))
423-
dimension_separator = draw(st.sampled_from([".", "/"]))
424-
# compressor = draw(zarr_compressors())
425-
filters = tuple(draw(zarr_filters())) if draw(st.booleans()) else None
426-
attributes = draw(simple_attrs())
427-
428-
# Construct the metadata object. Type hints are crucial here for correctness.
429-
return ArrayV2Metadata(
430-
shape=shape,
431-
dtype=dtype,
432-
chunks=chunks,
433-
fill_value=fill_value,
434-
order=order,
435-
dimension_separator=dimension_separator,
436-
# compressor=compressor,
437-
filters=filters,
438-
attributes=attributes,
439-
)
440-
441-
442-
@st.composite
443-
def array_metadata_v3(draw: st.DrawFn) -> ArrayV3Metadata:
444-
"""Generates valid ArrayV3Metadata objects for property-based testing."""
445-
dims = draw(st.integers(min_value=1, max_value=3))
446-
shape = tuple(draw(array_shapes(min_dims=dims, max_dims=dims, max_len=100)))
447-
max_chunk_len = max(shape) if shape else 100
448-
chunks = tuple(
449-
draw(
450-
st.lists(
451-
st.integers(min_value=1, max_value=max_chunk_len), min_size=dims, max_size=dims
452-
)
453-
)
454-
)
455-
assume(all(c <= s for s, c in zip(shape, chunks, strict=False)))
456-
457-
dtype = draw(v3_dtypes())
458-
fill_value = draw(npst.from_dtype(dtype))
459-
chunk_grid = RegularChunkGrid(chunks) # Ensure chunks is passed as tuple.
460-
chunk_key_encoding = DefaultChunkKeyEncoding(separator="/") # Or st.sampled_from(["/", "."])
461-
# codecs = tuple(draw(st.lists(zarr_codecs(), min_size=0, max_size=3)))
462-
attributes = draw(simple_attrs())
463-
dimension_names = (
464-
tuple(draw(st.lists(st.one_of(st.none(), simple_text()), min_size=dims, max_size=dims)))
465-
if draw(st.booleans())
466-
else None
467-
)
468-
storage_transformers = tuple(draw(zarr_storage_transformers()))
469-
470-
return ArrayV3Metadata(
471-
shape=shape,
472-
data_type=dtype,
473-
chunk_grid=chunk_grid,
474-
chunk_key_encoding=chunk_key_encoding,
475-
fill_value=fill_value,
476-
# codecs=codecs,
477-
attributes=attributes,
478-
dimension_names=dimension_names,
479-
storage_transformers=storage_transformers,
480-
)

tests/test_properties.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@
1414
from hypothesis import assume, given
1515

1616
from zarr.abc.store import Store
17-
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON
17+
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
1818
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
1919
from zarr.core.sync import sync
2020
from zarr.testing.strategies import (
2121
array_metadata,
22-
array_metadata_v2,
2322
arrays,
2423
basic_indices,
2524
numpy_arrays,
@@ -30,14 +29,12 @@
3029

3130

3231
def deep_equal(a, b):
33-
"""Deep equality check w/ NaN e to handle array metadata serialization and deserialization behaviors"""
32+
"""Deep equality check with handling of special cases for array metadata classes"""
3433
if isinstance(a, (complex, np.complexfloating)) and isinstance(
3534
b, (complex, np.complexfloating)
3635
):
37-
# Convert to Python float to force standard NaN handling.
3836
a_real, a_imag = float(a.real), float(a.imag)
3937
b_real, b_imag = float(b.real), float(b.imag)
40-
# If both parts are NaN, consider them equal.
4138
if np.isnan(a_real) and np.isnan(b_real):
4239
real_eq = True
4340
else:
@@ -48,43 +45,36 @@ def deep_equal(a, b):
4845
imag_eq = a_imag == b_imag
4946
return real_eq and imag_eq
5047

51-
# Handle floats (including numpy floating types) and treat NaNs as equal.
5248
if isinstance(a, (float, np.floating)) and isinstance(b, (float, np.floating)):
5349
if np.isnan(a) and np.isnan(b):
5450
return True
5551
return a == b
5652

57-
# Handle numpy.datetime64 values, treating NaT as equal.
5853
if isinstance(a, np.datetime64) and isinstance(b, np.datetime64):
5954
if np.isnat(a) and np.isnat(b):
6055
return True
6156
return a == b
6257

63-
# Handle numpy arrays.
6458
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
6559
if a.shape != b.shape:
6660
return False
67-
# Compare elementwise.
6861
return all(deep_equal(x, y) for x, y in zip(a.flat, b.flat, strict=False))
6962

70-
# Handle dictionaries.
7163
if isinstance(a, dict) and isinstance(b, dict):
7264
if set(a.keys()) != set(b.keys()):
7365
return False
7466
return all(deep_equal(a[k], b[k]) for k in a)
7567

76-
# Handle lists and tuples.
7768
if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
7869
if len(a) != len(b):
7970
return False
8071
return all(deep_equal(x, y) for x, y in zip(a, b, strict=False))
8172

82-
# Fallback to default equality.
8373
return a == b
8474

8575

8676
@given(data=st.data(), zarr_format=zarr_formats)
87-
def test_roundtrip(data: st.DataObject, zarr_format: int) -> None:
77+
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
8878
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
8979
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
9080
assert_array_equal(nparray, zarray[:])
@@ -197,18 +187,21 @@ async def test_roundtrip_array_metadata(
197187
# assert_array_equal(nparray, zarray[:])
198188

199189

200-
@given(array_metadata_v2())
201-
def test_v2meta_roundtrip(metadata):
190+
@given(data=st.data(), zarr_format=zarr_formats)
191+
def test_meta_roundtrip(data: st.DataObject, zarr_format: int) -> None:
192+
metadata = data.draw(array_metadata(zarr_formats=st.just(zarr_format)))
202193
buffer_dict = metadata.to_buffer_dict(prototype=default_buffer_prototype())
203-
zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode())
204-
zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode())
205194

206-
# zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict`
207-
zarray_dict["attributes"] = zattrs_dict
208-
209-
metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict)
195+
if zarr_format == 2:
196+
zarray_dict = json.loads(buffer_dict[ZARRAY_JSON].to_bytes().decode())
197+
zattrs_dict = json.loads(buffer_dict[ZATTRS_JSON].to_bytes().decode())
198+
# zattrs and zarray are separate in v2, we have to add attributes back prior to `from_dict`
199+
zarray_dict["attributes"] = zattrs_dict
200+
metadata_roundtripped = ArrayV2Metadata.from_dict(zarray_dict)
201+
else:
202+
zarray_dict = json.loads(buffer_dict[ZARR_JSON].to_bytes().decode())
203+
metadata_roundtripped = ArrayV3Metadata.from_dict(zarray_dict)
210204

211-
# Convert both metadata instances to dictionaries.
212205
orig = dataclasses.asdict(metadata)
213206
rt = dataclasses.asdict(metadata_roundtripped)
214207

0 commit comments

Comments
 (0)