Skip to content

Commit 549cf28

Browse files
authored
Finish typing zarr.metadata (#1880)
1 parent 55f4913 commit 549cf28

File tree

6 files changed

+32
-27
lines changed

6 files changed

+32
-27
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ module = [
239239
"zarr.array_v2",
240240
"zarr.array",
241241
"zarr.group",
242-
"zarr.metadata"
243242
]
244243
disallow_untyped_defs = false
245244

src/zarr/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ async def _create_v3(
199199

200200
if chunk_key_encoding is None:
201201
chunk_key_encoding = ("default", "/")
202+
assert chunk_key_encoding is not None
203+
202204
if isinstance(chunk_key_encoding, tuple):
203205
chunk_key_encoding = (
204206
V2ChunkKeyEncoding(separator=chunk_key_encoding[1])

src/zarr/chunk_grids.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import itertools
44
from collections.abc import Iterator
55
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING
77

88
from zarr.abc.metadata import Metadata
99
from zarr.common import (
@@ -22,13 +22,13 @@
2222
@dataclass(frozen=True)
2323
class ChunkGrid(Metadata):
2424
@classmethod
25-
def from_dict(cls, data: dict[str, JSON]) -> ChunkGrid:
25+
def from_dict(cls, data: dict[str, JSON] | ChunkGrid) -> ChunkGrid:
2626
if isinstance(data, ChunkGrid):
2727
return data
2828

2929
name_parsed, _ = parse_named_configuration(data)
3030
if name_parsed == "regular":
31-
return RegularChunkGrid.from_dict(data)
31+
return RegularChunkGrid._from_dict(data)
3232
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.")
3333

3434
def all_chunk_coords(self, array_shape: ChunkCoords) -> Iterator[ChunkCoords]:
@@ -45,7 +45,7 @@ def __init__(self, *, chunk_shape: ChunkCoordsLike) -> None:
4545
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
4646

4747
@classmethod
48-
def from_dict(cls, data: dict[str, Any]) -> Self:
48+
def _from_dict(cls, data: dict[str, JSON]) -> Self:
4949
_, configuration_parsed = parse_named_configuration(data, "regular")
5050

5151
return cls(**configuration_parsed) # type: ignore[arg-type]

src/zarr/chunk_key_encodings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:
3434
object.__setattr__(self, "separator", separator_parsed)
3535

3636
@classmethod
37-
def from_dict(cls, data: dict[str, JSON]) -> ChunkKeyEncoding:
37+
def from_dict(cls, data: dict[str, JSON] | ChunkKeyEncoding) -> ChunkKeyEncoding:
3838
if isinstance(data, ChunkKeyEncoding):
3939
return data
4040

src/zarr/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Awaitable, Callable, Iterator
1414

1515
import numpy as np
16+
import numpy.typing as npt
1617

1718
ZARR_JSON = "zarr.json"
1819
ZARRAY_JSON = ".zarray"
@@ -150,7 +151,7 @@ def parse_named_configuration(
150151
return name_parsed, configuration_parsed
151152

152153

153-
def parse_shapelike(data: Any) -> tuple[int, ...]:
154+
def parse_shapelike(data: Iterable[int]) -> tuple[int, ...]:
154155
if not isinstance(data, Iterable):
155156
raise TypeError(f"Expected an iterable. Got {data} instead.")
156157
data_tuple = tuple(data)
@@ -164,7 +165,7 @@ def parse_shapelike(data: Any) -> tuple[int, ...]:
164165
return data_tuple
165166

166167

167-
def parse_dtype(data: Any) -> np.dtype[Any]:
168+
def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
168169
# todo: real validation
169170
return np.dtype(data)
170171

src/zarr/metadata.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from typing_extensions import Self
2424

25+
import numcodecs.abc
2526

2627
from zarr.common import (
2728
JSON,
@@ -168,15 +169,15 @@ class ArrayV3Metadata(ArrayMetadata):
168169
def __init__(
169170
self,
170171
*,
171-
shape,
172-
data_type,
173-
chunk_grid,
174-
chunk_key_encoding,
175-
fill_value,
176-
codecs,
177-
attributes,
178-
dimension_names,
179-
):
172+
shape: Iterable[int],
173+
data_type: npt.DTypeLike,
174+
chunk_grid: dict[str, JSON] | ChunkGrid,
175+
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
176+
fill_value: Any,
177+
codecs: Iterable[Codec | JSON],
178+
attributes: None | dict[str, JSON],
179+
dimension_names: None | Iterable[str],
180+
) -> None:
180181
"""
181182
Because the class is a frozen dataclass, we set attributes using object.__setattr__
182183
"""
@@ -249,14 +250,14 @@ def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
249250
return self.chunk_key_encoding.encode_chunk_key(chunk_coords)
250251

251252
def to_buffer_dict(self) -> dict[str, Buffer]:
252-
def _json_convert(o):
253+
def _json_convert(o: np.dtype[Any] | Enum | Codec) -> str | dict[str, Any]:
253254
if isinstance(o, np.dtype):
254255
return str(o)
255256
if isinstance(o, Enum):
256257
return o.name
257258
# this serializes numcodecs compressors
258259
# todo: implement to_dict for codecs
259-
elif hasattr(o, "get_config"):
260+
elif isinstance(o, numcodecs.abc.Codec):
260261
return o.get_config()
261262
raise TypeError
262263

@@ -271,9 +272,10 @@ def from_dict(cls, data: dict[str, JSON]) -> ArrayV3Metadata:
271272
# check that the node_type attribute is correct
272273
_ = parse_node_type_array(data.pop("node_type"))
273274

274-
dimension_names = data.pop("dimension_names", None)
275+
data["dimension_names"] = data.pop("dimension_names", None)
275276

276-
return cls(**data, dimension_names=dimension_names)
277+
# TODO: Remove the ignores and use a TypedDict to type `data`
278+
return cls(**data) # type: ignore[arg-type]
277279

278280
def to_dict(self) -> dict[str, Any]:
279281
out_dict = super().to_dict()
@@ -367,7 +369,9 @@ def codec_pipeline(self) -> CodecPipeline:
367369
)
368370

369371
def to_buffer_dict(self) -> dict[str, Buffer]:
370-
def _json_convert(o):
372+
def _json_convert(
373+
o: np.dtype[Any],
374+
) -> str | list[tuple[str, str] | tuple[str, str, tuple[int, ...]]]:
371375
if isinstance(o, np.dtype):
372376
if o.fields is None:
373377
return o.str
@@ -399,7 +403,7 @@ def to_dict(self) -> JSON:
399403
zarray_dict["chunks"] = self.chunk_grid.chunk_shape
400404

401405
_ = zarray_dict.pop("data_type")
402-
zarray_dict["dtype"] = self.data_type
406+
zarray_dict["dtype"] = self.data_type.str
403407

404408
return zarray_dict
405409

@@ -422,7 +426,7 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
422426
return replace(self, attributes=attributes)
423427

424428

425-
def parse_dimension_names(data: Any) -> tuple[str, ...] | None:
429+
def parse_dimension_names(data: None | Iterable[str]) -> tuple[str, ...] | None:
426430
if data is None:
427431
return data
428432
if isinstance(data, Iterable) and all([isinstance(x, str) for x in data]):
@@ -432,12 +436,11 @@ def parse_dimension_names(data: Any) -> tuple[str, ...] | None:
432436

433437

434438
# todo: real validation
435-
def parse_attributes(data: Any) -> dict[str, JSON]:
439+
def parse_attributes(data: None | dict[str, JSON]) -> dict[str, JSON]:
436440
if data is None:
437441
return {}
438442

439-
data_json = cast(dict[str, JSON], data)
440-
return data_json
443+
return data
441444

442445

443446
# todo: move to its own module and drop _v3 suffix

0 commit comments

Comments
 (0)