Skip to content

Commit b8baa68

Browse files
authored
Cast fill value to array's dtype (#2020)
* add fill value parsing routines and tests * add fill_value attribute to array, and test that it works as expected for v3 arrays * Update tests/v3/test_metadata/test_v3.py * clean up docstrings
1 parent 33b1589 commit b8baa68

File tree

6 files changed

+354
-68
lines changed

6 files changed

+354
-68
lines changed

src/zarr/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,10 @@ def order(self) -> Literal["C", "F"]:
715715
def read_only(self) -> bool:
716716
return self._async_array.read_only
717717

718+
@property
719+
def fill_value(self) -> Any:
720+
return self.metadata.fill_value
721+
718722
def __array__(
719723
self, dtype: npt.DTypeLike | None = None, copy: bool | None = None
720724
) -> NDArrayLike:

src/zarr/metadata.py

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import json
44
from abc import ABC, abstractmethod
5-
from collections.abc import Iterable
5+
from collections.abc import Iterable, Sequence
66
from dataclasses import dataclass, field, replace
77
from enum import Enum
8-
from typing import TYPE_CHECKING, Any, Literal
8+
from typing import TYPE_CHECKING, Any, Literal, cast, overload
99

1010
import numpy as np
1111
import numpy.typing as npt
@@ -32,7 +32,6 @@
3232
ChunkCoords,
3333
ZarrFormat,
3434
parse_dtype,
35-
parse_fill_value,
3635
parse_named_configuration,
3736
parse_shapelike,
3837
)
@@ -189,7 +188,7 @@ def __init__(
189188
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
190189
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
191190
dimension_names_parsed = parse_dimension_names(dimension_names)
192-
fill_value_parsed = parse_fill_value(fill_value)
191+
fill_value_parsed = parse_fill_value_v3(fill_value, dtype=data_type_parsed)
193192
attributes_parsed = parse_attributes(attributes)
194193
codecs_parsed_partial = parse_codecs(codecs)
195194

@@ -255,9 +254,18 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
255254
return self.chunk_key_encoding.encode_chunk_key(chunk_coords)
256255

257256
def to_buffer_dict(self) -> dict[str, Buffer]:
258-
def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
257+
def _json_convert(o: Any) -> Any:
259258
if isinstance(o, np.dtype):
260259
return str(o)
260+
if np.isscalar(o):
261+
# convert numpy scalar to python type, and pass
262+
# python types through
263+
out = getattr(o, "item", lambda: o)()
264+
if isinstance(out, complex):
265+
# python complex types are not JSON serializable, so we use the
266+
# serialization defined in the zarr v3 spec
267+
return [out.real, out.imag]
268+
return out
261269
if isinstance(o, Enum):
262270
return o.name
263271
# this serializes numcodecs compressors
@@ -341,7 +349,7 @@ def __init__(
341349
order_parsed = parse_indexing_order(order)
342350
dimension_separator_parsed = parse_separator(dimension_separator)
343351
filters_parsed = parse_filters(filters)
344-
fill_value_parsed = parse_fill_value(fill_value)
352+
fill_value_parsed = parse_fill_value_v2(fill_value, dtype=data_type_parsed)
345353
attributes_parsed = parse_attributes(attributes)
346354

347355
object.__setattr__(self, "shape", shape_parsed)
@@ -371,13 +379,17 @@ def chunks(self) -> ChunkCoords:
371379

372380
def to_buffer_dict(self) -> dict[str, Buffer]:
373381
def _json_convert(
374-
o: np.dtype[Any],
375-
) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]:
382+
o: Any,
383+
) -> Any:
376384
if isinstance(o, np.dtype):
377385
if o.fields is None:
378386
return o.str
379387
else:
380388
return o.descr
389+
if np.isscalar(o):
390+
# convert numpy scalar to python type, and pass
391+
# python types through
392+
return getattr(o, "item", lambda: o)()
381393
raise TypeError
382394

383395
zarray_dict = self.to_dict()
@@ -517,3 +529,105 @@ def parse_codecs(data: Iterable[Codec | dict[str, JSON]]) -> tuple[Codec, ...]:
517529
out += (get_codec_class(name_parsed).from_dict(c),)
518530

519531
return out
532+
533+
534+
def parse_fill_value_v2(fill_value: Any, dtype: np.dtype[Any]) -> Any:
535+
"""
536+
Parse a potential fill value into a value that is compatible with the provided dtype.
537+
538+
This is a light wrapper around zarr.v2.util.normalize_fill_value.
539+
540+
Parameters
541+
----------
542+
fill_value: Any
543+
A potential fill value.
544+
dtype: np.dtype[Any]
545+
A numpy dtype.
546+
547+
Returns
548+
An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
549+
"""
550+
from zarr.v2.util import normalize_fill_value
551+
552+
return normalize_fill_value(fill_value=fill_value, dtype=dtype)
553+
554+
555+
BOOL = np.bool_
556+
BOOL_DTYPE = np.dtypes.BoolDType
557+
558+
INTEGER_DTYPE = (
559+
np.dtypes.Int8DType
560+
| np.dtypes.Int16DType
561+
| np.dtypes.Int32DType
562+
| np.dtypes.Int64DType
563+
| np.dtypes.UByteDType
564+
| np.dtypes.UInt16DType
565+
| np.dtypes.UInt32DType
566+
| np.dtypes.UInt64DType
567+
)
568+
569+
INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64
570+
FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType
571+
FLOAT = np.float16 | np.float32 | np.float64
572+
COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType
573+
COMPLEX = np.complex64 | np.complex128
574+
# todo: r* dtypes
575+
576+
577+
@overload
578+
def parse_fill_value_v3(fill_value: Any, dtype: BOOL_DTYPE) -> BOOL: ...
579+
580+
581+
@overload
582+
def parse_fill_value_v3(fill_value: Any, dtype: INTEGER_DTYPE) -> INTEGER: ...
583+
584+
585+
@overload
586+
def parse_fill_value_v3(fill_value: Any, dtype: FLOAT_DTYPE) -> FLOAT: ...
587+
588+
589+
@overload
590+
def parse_fill_value_v3(fill_value: Any, dtype: COMPLEX_DTYPE) -> COMPLEX: ...
591+
592+
593+
def parse_fill_value_v3(
594+
fill_value: Any, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE
595+
) -> BOOL | INTEGER | FLOAT | COMPLEX:
596+
"""
597+
Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type.
598+
If `fill_value` is `None`, then this function will return the result of casting the value 0
599+
to the provided data type. Otherwise, `fill_value` will be cast to the provided data type.
600+
601+
Note that some numpy dtypes use very permissive casting rules. For example,
602+
`np.bool_({'not remotely a bool'})` returns `True`. Thus this function should not be used for
603+
validating that the provided fill value is a valid instance of the data type.
604+
605+
Parameters
606+
----------
607+
fill_value: Any
608+
A potential fill value.
609+
dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE
610+
A numpy data type that models a data type defined in the Zarr V3 specification.
611+
612+
Returns
613+
-------
614+
A scalar instance of `dtype`
615+
"""
616+
if fill_value is None:
617+
return dtype.type(0)
618+
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
619+
if dtype in (np.complex64, np.complex128):
620+
dtype = cast(COMPLEX_DTYPE, dtype)
621+
if len(fill_value) == 2:
622+
# complex datatypes serialize to JSON arrays with two elements
623+
return dtype.type(complex(*fill_value))
624+
else:
625+
msg = (
626+
f"Got an invalid fill value for complex data type {dtype}."
627+
f"Expected a sequence with 2 elements, but {fill_value} has "
628+
f"length {len(fill_value)}."
629+
)
630+
raise ValueError(msg)
631+
msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}."
632+
raise TypeError(msg)
633+
return dtype.type(fill_value)

tests/v3/test_array.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23

34
from zarr.array import Array
@@ -34,3 +35,51 @@ def test_array_name_properties_with_group(
3435
assert spam.path == "bar/spam"
3536
assert spam.name == "/bar/spam"
3637
assert spam.basename == "spam"
38+
39+
40+
@pytest.mark.parametrize("store", ["memory"], indirect=True)
41+
@pytest.mark.parametrize("specifiy_fill_value", [True, False])
42+
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "complex64"])
43+
def test_array_v3_fill_value_default(
44+
store: MemoryStore, specifiy_fill_value: bool, dtype_str: str
45+
) -> None:
46+
"""
47+
Test that creating an array with the fill_value parameter set to None, or unspecified,
48+
results in the expected fill_value attribute of the array, i.e. 0 cast to the array's dtype.
49+
"""
50+
shape = (10,)
51+
default_fill_value = 0
52+
if specifiy_fill_value:
53+
arr = Array.create(
54+
store=store,
55+
shape=shape,
56+
dtype=dtype_str,
57+
zarr_format=3,
58+
chunk_shape=shape,
59+
fill_value=None,
60+
)
61+
else:
62+
arr = Array.create(
63+
store=store, shape=shape, dtype=dtype_str, zarr_format=3, chunk_shape=shape
64+
)
65+
66+
assert arr.fill_value == np.dtype(dtype_str).type(default_fill_value)
67+
assert arr.fill_value.dtype == arr.dtype
68+
69+
70+
@pytest.mark.parametrize("store", ["memory"], indirect=True)
71+
@pytest.mark.parametrize("fill_value", [False, 0.0, 1, 2.3])
72+
@pytest.mark.parametrize("dtype_str", ["bool", "uint8", "float32", "complex64"])
73+
def test_array_v3_fill_value(store: MemoryStore, fill_value: int, dtype_str: str) -> None:
74+
shape = (10,)
75+
arr = Array.create(
76+
store=store,
77+
shape=shape,
78+
dtype=dtype_str,
79+
zarr_format=3,
80+
chunk_shape=shape,
81+
fill_value=fill_value,
82+
)
83+
84+
assert arr.fill_value == np.dtype(dtype_str).type(fill_value)
85+
assert arr.fill_value.dtype == arr.dtype

tests/v3/test_metadata.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +0,0 @@
1-
from __future__ import annotations
2-
3-
from typing import TYPE_CHECKING
4-
5-
import pytest
6-
7-
if TYPE_CHECKING:
8-
from collections.abc import Sequence
9-
from typing import Any
10-
11-
from zarr.metadata import parse_dimension_names, parse_zarr_format_v2, parse_zarr_format_v3
12-
13-
14-
# todo: test
15-
def test_datatype_enum(): ...
16-
17-
18-
# todo: test
19-
# this will almost certainly be a collection of tests
20-
def test_array_metadata_v3(): ...
21-
22-
23-
# todo: test
24-
# this will almost certainly be a collection of tests
25-
def test_array_metadata_v2(): ...
26-
27-
28-
@pytest.mark.parametrize("data", [None, ("a", "b", "c"), ["a", "a", "a"]])
29-
def parse_dimension_names_valid(data: Sequence[str] | None) -> None:
30-
assert parse_dimension_names(data) == data
31-
32-
33-
@pytest.mark.parametrize("data", [(), [1, 2, "a"], {"foo": 10}])
34-
def parse_dimension_names_invalid(data: Any) -> None:
35-
with pytest.raises(TypeError, match="Expected either None or iterable of str,"):
36-
parse_dimension_names(data)
37-
38-
39-
# todo: test
40-
def test_parse_attributes() -> None: ...
41-
42-
43-
def test_parse_zarr_format_v3_valid() -> None:
44-
assert parse_zarr_format_v3(3) == 3
45-
46-
47-
@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
48-
def test_parse_zarr_foramt_v3_invalid(data: Any) -> None:
49-
with pytest.raises(ValueError, match=f"Invalid value. Expected 3. Got {data}"):
50-
parse_zarr_format_v3(data)
51-
52-
53-
def test_parse_zarr_format_v2_valid() -> None:
54-
assert parse_zarr_format_v2(2) == 2
55-
56-
57-
@pytest.mark.parametrize("data", [None, 1, 3, 4, 5, "3"])
58-
def test_parse_zarr_foramt_v2_invalid(data: Any) -> None:
59-
with pytest.raises(ValueError, match=f"Invalid value. Expected 2. Got {data}"):
60-
parse_zarr_format_v2(data)

tests/v3/test_metadata/test_v2.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from typing import Any
7+
8+
import pytest
9+
10+
from zarr.metadata import parse_zarr_format_v2
11+
12+
13+
def test_parse_zarr_format_valid() -> None:
14+
assert parse_zarr_format_v2(2) == 2
15+
16+
17+
@pytest.mark.parametrize("data", [None, 1, 3, 4, 5, "3"])
18+
def test_parse_zarr_format_invalid(data: Any) -> None:
19+
with pytest.raises(ValueError, match=f"Invalid value. Expected 2. Got {data}"):
20+
parse_zarr_format_v2(data)

0 commit comments

Comments
 (0)