1
1
from __future__ import annotations
2
2
from enum import Enum
3
- from typing import TYPE_CHECKING , Iterable , Mapping , NamedTuple , Union
3
+ from typing import TYPE_CHECKING , Any , Iterable , Mapping , NamedTuple , Union , Optional
4
4
from dataclasses import dataclass , replace
5
5
from functools import lru_cache
6
6
7
7
8
8
import numpy as np
9
+ import numpy .typing as npt
9
10
from zarr .abc .codec import (
10
11
Codec ,
11
12
ArrayBytesCodec ,
18
19
from zarr .codecs .registry import register_codec
19
20
from zarr .common import (
20
21
ArraySpec ,
22
+ BytesLike ,
21
23
ChunkCoordsLike ,
24
+ ChunkCoords ,
22
25
concurrent_map ,
23
26
parse_enum ,
24
27
parse_named_configuration ,
39
42
)
40
43
41
44
if TYPE_CHECKING :
42
- from typing import Awaitable , Callable , Dict , Iterator , List , Optional , Set , Tuple
45
+ from typing import Awaitable , Callable , Dict , Iterator , List , Set , Tuple
43
46
from typing_extensions import Self
44
47
45
48
from zarr .store import StorePath
46
49
from zarr .common import (
47
50
JSON ,
48
- ChunkCoords ,
49
- BytesLike ,
50
51
SliceSelection ,
51
52
)
52
53
from zarr .config import RuntimeConfiguration
@@ -65,7 +66,7 @@ def parse_index_location(data: JSON) -> ShardingCodecIndexLocation:
65
66
66
67
class _ShardIndex (NamedTuple ):
67
68
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
68
- offsets_and_lengths : np .ndarray
69
+ offsets_and_lengths : npt . NDArray [ np .uint64 ]
69
70
70
71
@property
71
72
def chunks_per_shard (self ) -> ChunkCoords :
@@ -126,7 +127,10 @@ def create_empty(cls, chunks_per_shard: ChunkCoords) -> _ShardIndex:
126
127
return cls (offsets_and_lengths )
127
128
128
129
129
- class _ShardProxy (Mapping ):
130
+ _ShardMapping = Mapping [ChunkCoords , Optional [BytesLike ]]
131
+
132
+
133
+ class _ShardProxy (_ShardMapping ):
130
134
index : _ShardIndex
131
135
buf : BytesLike
132
136
@@ -175,7 +179,7 @@ def merge_with_morton_order(
175
179
cls ,
176
180
chunks_per_shard : ChunkCoords ,
177
181
tombstones : Set [ChunkCoords ],
178
- * shard_dicts : Mapping [ ChunkCoords , BytesLike ] ,
182
+ * shard_dicts : _ShardMapping ,
179
183
) -> _ShardBuilder :
180
184
obj = cls .create_empty (chunks_per_shard )
181
185
for chunk_coords in morton_order_iter (chunks_per_shard ):
@@ -303,7 +307,7 @@ async def decode(
303
307
shard_bytes : BytesLike ,
304
308
shard_spec : ArraySpec ,
305
309
runtime_configuration : RuntimeConfiguration ,
306
- ) -> np . ndarray :
310
+ ) -> npt . NDArray [ Any ] :
307
311
# print("decode")
308
312
shard_shape = shard_spec .shape
309
313
chunk_shape = self .chunk_shape
@@ -353,7 +357,7 @@ async def decode_partial(
353
357
selection : SliceSelection ,
354
358
shard_spec : ArraySpec ,
355
359
runtime_configuration : RuntimeConfiguration ,
356
- ) -> Optional [np . ndarray ]:
360
+ ) -> Optional [npt . NDArray [ Any ] ]:
357
361
shard_shape = shard_spec .shape
358
362
chunk_shape = self .chunk_shape
359
363
chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
@@ -375,7 +379,7 @@ async def decode_partial(
375
379
all_chunk_coords = set (chunk_coords for chunk_coords , _ , _ in indexed_chunks )
376
380
377
381
# reading bytes of all requested chunks
378
- shard_dict : Mapping [ ChunkCoords , BytesLike ] = {}
382
+ shard_dict : _ShardMapping = {}
379
383
if self ._is_total_shard (all_chunk_coords , chunks_per_shard ):
380
384
# read entire shard
381
385
shard_dict_maybe = await self ._load_full_shard_maybe (store_path , chunks_per_shard )
@@ -423,7 +427,7 @@ async def _read_chunk(
423
427
out_selection : SliceSelection ,
424
428
shard_spec : ArraySpec ,
425
429
runtime_configuration : RuntimeConfiguration ,
426
- out : np . ndarray ,
430
+ out : npt . NDArray [ Any ] ,
427
431
) -> None :
428
432
chunk_spec = self ._get_chunk_spec (shard_spec )
429
433
chunk_bytes = shard_dict .get (chunk_coords , None )
@@ -436,7 +440,7 @@ async def _read_chunk(
436
440
437
441
async def encode (
438
442
self ,
439
- shard_array : np . ndarray ,
443
+ shard_array : npt . NDArray [ Any ] ,
440
444
shard_spec : ArraySpec ,
441
445
runtime_configuration : RuntimeConfiguration ,
442
446
) -> Optional [BytesLike ]:
@@ -453,7 +457,7 @@ async def encode(
453
457
)
454
458
455
459
async def _write_chunk (
456
- shard_array : np . ndarray ,
460
+ shard_array : npt . NDArray [ Any ] ,
457
461
chunk_coords : ChunkCoords ,
458
462
chunk_selection : SliceSelection ,
459
463
out_selection : SliceSelection ,
@@ -498,7 +502,7 @@ async def _write_chunk(
498
502
async def encode_partial (
499
503
self ,
500
504
store_path : StorePath ,
501
- shard_array : np . ndarray ,
505
+ shard_array : npt . NDArray [ Any ] ,
502
506
selection : SliceSelection ,
503
507
shard_spec : ArraySpec ,
504
508
runtime_configuration : RuntimeConfiguration ,
0 commit comments