|
2 | 2 |
|
3 | 3 | import json
|
4 | 4 | from abc import ABC, abstractmethod
|
5 |
| -from collections.abc import Iterable |
| 5 | +from collections.abc import Iterable, Sequence |
6 | 6 | from dataclasses import dataclass, field, replace
|
7 | 7 | from enum import Enum
|
8 |
| -from typing import TYPE_CHECKING, Any, Literal |
| 8 | +from typing import TYPE_CHECKING, Any, Literal, cast, overload |
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 | import numpy.typing as npt
|
|
32 | 32 | ChunkCoords,
|
33 | 33 | ZarrFormat,
|
34 | 34 | parse_dtype,
|
35 |
| - parse_fill_value, |
36 | 35 | parse_named_configuration,
|
37 | 36 | parse_shapelike,
|
38 | 37 | )
|
@@ -189,7 +188,7 @@ def __init__(
|
189 | 188 | chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
|
190 | 189 | chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
|
191 | 190 | dimension_names_parsed = parse_dimension_names(dimension_names)
|
192 |
| - fill_value_parsed = parse_fill_value(fill_value) |
| 191 | + fill_value_parsed = parse_fill_value_v3(fill_value, dtype=data_type_parsed) |
193 | 192 | attributes_parsed = parse_attributes(attributes)
|
194 | 193 | codecs_parsed_partial = parse_codecs(codecs)
|
195 | 194 |
|
@@ -255,9 +254,18 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
|
255 | 254 | return self.chunk_key_encoding.encode_chunk_key(chunk_coords)
|
256 | 255 |
|
257 | 256 | def to_buffer_dict(self) -> dict[str, Buffer]:
|
258 |
| - def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]: |
| 257 | + def _json_convert(o: Any) -> Any: |
259 | 258 | if isinstance(o, np.dtype):
|
260 | 259 | return str(o)
|
| 260 | + if np.isscalar(o): |
| 261 | + # convert numpy scalar to python type, and pass |
| 262 | + # python types through |
| 263 | + out = getattr(o, "item", lambda: o)() |
| 264 | + if isinstance(out, complex): |
| 265 | + # python complex types are not JSON serializable, so we use the |
| 266 | + # serialization defined in the zarr v3 spec |
| 267 | + return [out.real, out.imag] |
| 268 | + return out |
261 | 269 | if isinstance(o, Enum):
|
262 | 270 | return o.name
|
263 | 271 | # this serializes numcodecs compressors
|
@@ -341,7 +349,7 @@ def __init__(
|
341 | 349 | order_parsed = parse_indexing_order(order)
|
342 | 350 | dimension_separator_parsed = parse_separator(dimension_separator)
|
343 | 351 | filters_parsed = parse_filters(filters)
|
344 |
| - fill_value_parsed = parse_fill_value(fill_value) |
| 352 | + fill_value_parsed = parse_fill_value_v2(fill_value, dtype=data_type_parsed) |
345 | 353 | attributes_parsed = parse_attributes(attributes)
|
346 | 354 |
|
347 | 355 | object.__setattr__(self, "shape", shape_parsed)
|
@@ -371,13 +379,17 @@ def chunks(self) -> ChunkCoords:
|
371 | 379 |
|
372 | 380 | def to_buffer_dict(self) -> dict[str, Buffer]:
|
373 | 381 | def _json_convert(
|
374 |
| - o: np.dtype[Any], |
375 |
| - ) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]: |
| 382 | + o: Any, |
| 383 | + ) -> Any: |
376 | 384 | if isinstance(o, np.dtype):
|
377 | 385 | if o.fields is None:
|
378 | 386 | return o.str
|
379 | 387 | else:
|
380 | 388 | return o.descr
|
| 389 | + if np.isscalar(o): |
| 390 | + # convert numpy scalar to python type, and pass |
| 391 | + # python types through |
| 392 | + return getattr(o, "item", lambda: o)() |
381 | 393 | raise TypeError
|
382 | 394 |
|
383 | 395 | zarray_dict = self.to_dict()
|
@@ -517,3 +529,105 @@ def parse_codecs(data: Iterable[Codec | dict[str, JSON]]) -> tuple[Codec, ...]:
|
517 | 529 | out += (get_codec_class(name_parsed).from_dict(c),)
|
518 | 530 |
|
519 | 531 | return out
|
| 532 | + |
| 533 | + |
def parse_fill_value_v2(fill_value: Any, dtype: np.dtype[Any]) -> Any:
    """
    Coerce a candidate fill value for a Zarr V2 array into one that fits `dtype`.

    All of the work is delegated to `zarr.v2.util.normalize_fill_value`; this
    function only forwards its two arguments to that helper by keyword.

    Parameters
    ----------
    fill_value: Any
        A potential fill value.
    dtype: np.dtype[Any]
        A numpy dtype.

    Returns
    -------
    An instance of `dtype`, or `None`, or any python object (in the case of an object dtype)
    """
    # The helper is imported when the function is called, not when this module
    # is imported.
    from zarr.v2.util import normalize_fill_value

    normalized = normalize_fill_value(fill_value=fill_value, dtype=dtype)
    return normalized
| 553 | + |
| 554 | + |
| 555 | +BOOL = np.bool_ |
| 556 | +BOOL_DTYPE = np.dtypes.BoolDType |
| 557 | + |
| 558 | +INTEGER_DTYPE = ( |
| 559 | + np.dtypes.Int8DType |
| 560 | + | np.dtypes.Int16DType |
| 561 | + | np.dtypes.Int32DType |
| 562 | + | np.dtypes.Int64DType |
| 563 | + | np.dtypes.UByteDType |
| 564 | + | np.dtypes.UInt16DType |
| 565 | + | np.dtypes.UInt32DType |
| 566 | + | np.dtypes.UInt64DType |
| 567 | +) |
| 568 | + |
| 569 | +INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 |
| 570 | +FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType |
| 571 | +FLOAT = np.float16 | np.float32 | np.float64 |
| 572 | +COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType |
| 573 | +COMPLEX = np.complex64 | np.complex128 |
| 574 | +# todo: r* dtypes |
| 575 | + |
| 576 | + |
| 577 | +@overload |
| 578 | +def parse_fill_value_v3(fill_value: Any, dtype: BOOL_DTYPE) -> BOOL: ... |
| 579 | + |
| 580 | + |
| 581 | +@overload |
| 582 | +def parse_fill_value_v3(fill_value: Any, dtype: INTEGER_DTYPE) -> INTEGER: ... |
| 583 | + |
| 584 | + |
| 585 | +@overload |
| 586 | +def parse_fill_value_v3(fill_value: Any, dtype: FLOAT_DTYPE) -> FLOAT: ... |
| 587 | + |
| 588 | + |
| 589 | +@overload |
| 590 | +def parse_fill_value_v3(fill_value: Any, dtype: COMPLEX_DTYPE) -> COMPLEX: ... |
| 591 | + |
| 592 | + |
| 593 | +def parse_fill_value_v3( |
| 594 | + fill_value: Any, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE |
| 595 | +) -> BOOL | INTEGER | FLOAT | COMPLEX: |
| 596 | + """ |
| 597 | + Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. |
| 598 | + If `fill_value` is `None`, then this function will return the result of casting the value 0 |
| 599 | + to the provided data type. Otherwise, `fill_value` will be cast to the provided data type. |
| 600 | +
|
| 601 | + Note that some numpy dtypes use very permissive casting rules. For example, |
| 602 | + `np.bool_({'not remotely a bool'})` returns `True`. Thus this function should not be used for |
| 603 | + validating that the provided fill value is a valid instance of the data type. |
| 604 | +
|
| 605 | + Parameters |
| 606 | + ---------- |
| 607 | + fill_value: Any |
| 608 | + A potential fill value. |
| 609 | + dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE |
| 610 | + A numpy data type that models a data type defined in the Zarr V3 specification. |
| 611 | +
|
| 612 | + Returns |
| 613 | + ------- |
| 614 | + A scalar instance of `dtype` |
| 615 | + """ |
| 616 | + if fill_value is None: |
| 617 | + return dtype.type(0) |
| 618 | + if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): |
| 619 | + if dtype in (np.complex64, np.complex128): |
| 620 | + dtype = cast(COMPLEX_DTYPE, dtype) |
| 621 | + if len(fill_value) == 2: |
| 622 | + # complex datatypes serialize to JSON arrays with two elements |
| 623 | + return dtype.type(complex(*fill_value)) |
| 624 | + else: |
| 625 | + msg = ( |
| 626 | + f"Got an invalid fill value for complex data type {dtype}." |
| 627 | + f"Expected a sequence with 2 elements, but {fill_value} has " |
| 628 | + f"length {len(fill_value)}." |
| 629 | + ) |
| 630 | + raise ValueError(msg) |
| 631 | + msg = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}." |
| 632 | + raise TypeError(msg) |
| 633 | + return dtype.type(fill_value) |
0 commit comments