@@ -16,11 +16,10 @@
 import numpy as np
 from attr import evolve, frozen
 
-from zarr.v3.abc.codec import ArrayBytesCodecPartialDecodeMixin
-
 
 # from zarr.v3.array_v2 import ArrayV2
 from zarr.v3.codecs import CodecMetadata, CodecPipeline, bytes_codec
+from zarr.v3.codecs.registry import get_codec_from_metadata
 from zarr.v3.common import (
     ZARR_JSON,
     ChunkCoords,
@@ -31,6 +30,7 @@
 from zarr.v3.indexing import BasicIndexer, all_chunk_coords, is_total_slice
 from zarr.v3.metadata import (
     ArrayMetadata,
+    ArraySpec,
     DataType,
     DefaultChunkKeyEncodingConfigurationMetadata,
     DefaultChunkKeyEncodingMetadata,
@@ -41,7 +41,6 @@
     V2ChunkKeyEncodingMetadata,
     dtype_to_data_type,
 )
-from zarr.v3.codecs.sharding import ShardingCodec
 from zarr.v3.store import StoreLike, StorePath, make_store_path
 from zarr.v3.sync import sync
 
@@ -118,8 +117,11 @@ async def create(
             metadata=metadata,
             store_path=store_path,
             runtime_configuration=runtime_configuration,
-            codec_pipeline=CodecPipeline.from_metadata(
-                metadata.codecs, metadata.get_core_metadata(runtime_configuration)
+            codec_pipeline=CodecPipeline.create(
+                [
+                    get_codec_from_metadata(codec).evolve(ndim=len(shape), data_type=data_type)
+                    for codec in codecs
+                ]
             ),
         )
 
@@ -134,13 +136,17 @@ def from_json(
         runtime_configuration: RuntimeConfiguration,
     ) -> AsyncArray:
         metadata = ArrayMetadata.from_json(zarr_json)
+        codecs = [
+            get_codec_from_metadata(codec).evolve(
+                ndim=len(metadata.shape), data_type=metadata.data_type
+            )
+            for codec in metadata.codecs
+        ]
         async_array = cls(
             metadata=metadata,
             store_path=store_path,
             runtime_configuration=runtime_configuration,
-            codec_pipeline=CodecPipeline.from_metadata(
-                metadata.codecs, metadata.get_core_metadata(runtime_configuration)
-            ),
+            codec_pipeline=CodecPipeline.create(codecs),
         )
         async_array._validate_metadata()
         return async_array
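
Both `create` and `from_json` now build the pipeline the same way: each entry in the metadata's codec list is resolved to a concrete codec instance through the registry (`get_codec_from_metadata`), specialized for this array via `evolve`, and the resulting list is handed to `CodecPipeline.create`. A minimal sketch of the shared pattern, using only names that appear in this diff:

    # Resolve codec metadata into array-aware codec instances, then build
    # the pipeline from them (replaces CodecPipeline.from_metadata).
    codecs = [
        get_codec_from_metadata(codec).evolve(
            ndim=len(metadata.shape),      # dimensionality of the array
            data_type=metadata.data_type,  # dtype the codec will operate on
        )
        for codec in metadata.codecs
    ]
    codec_pipeline = CodecPipeline.create(codecs)
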
@@ -240,6 +246,7 @@ def _validate_metadata(self) -> None:
             self.metadata.dimension_names
         ), "`dimension_names` and `shape` need to have the same number of dimensions."
         assert self.metadata.fill_value is not None, "`fill_value` is required."
+        self.codec_pipeline.validate(self.metadata)
 
     async def _read_chunk(
         self,
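
The added `self.codec_pipeline.validate(self.metadata)` call pushes codec/metadata consistency checks into the pipeline itself instead of the array. The body of `validate` is not part of this diff; one plausible shape, sketched purely as an assumption, is to delegate to each codec in turn:

    # Hypothetical sketch of CodecPipeline.validate -- not shown in this
    # diff. Assumes each codec exposes a validate(array_metadata) hook.
    def validate(self, array_metadata: ArrayMetadata) -> None:
        for codec in self.codecs:
            codec.validate(array_metadata)
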
@@ -248,15 +255,14 @@ async def _read_chunk(
         out_selection: SliceSelection,
         out: np.ndarray,
     ):
+        chunk_spec = self.metadata.get_chunk_spec(chunk_coords)
         chunk_key_encoding = self.metadata.chunk_key_encoding
         chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
         store_path = self.store_path / chunk_key
 
-        if len(self.codec_pipeline.codecs) == 1 and isinstance(
-            self.codec_pipeline.codecs[0], ArrayBytesCodecPartialDecodeMixin
-        ):
-            chunk_array = await self.codec_pipeline.codecs[0].decode_partial(
-                store_path, chunk_selection
+        if self.codec_pipeline.supports_partial_decode:
+            chunk_array = await self.codec_pipeline.decode_partial(
+                store_path, chunk_selection, chunk_spec, self.runtime_configuration
             )
             if chunk_array is not None:
                 out[out_selection] = chunk_array
@@ -265,7 +271,9 @@ async def _read_chunk(
         else:
             chunk_bytes = await store_path.get()
             if chunk_bytes is not None:
-                chunk_array = await self.codec_pipeline.decode(chunk_bytes)
+                chunk_array = await self.codec_pipeline.decode(
+                    chunk_bytes, chunk_spec, self.runtime_configuration
+                )
                 tmp = chunk_array[chunk_selection]
                 out[out_selection] = tmp
             else:
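
Both read branches now receive a `chunk_spec`, computed once per chunk from `self.metadata.get_chunk_spec(chunk_coords)`, so the pipeline no longer has to reach back into the array for per-chunk information. The resulting read path, condensed (all calls below appear in this diff; describing `ArraySpec` as carrying chunk shape, dtype, and fill value is an assumption):

    chunk_spec = self.metadata.get_chunk_spec(chunk_coords)  # per-chunk ArraySpec
    if self.codec_pipeline.supports_partial_decode:
        # decode only the requested selection straight from the store
        chunk_array = await self.codec_pipeline.decode_partial(
            store_path, chunk_selection, chunk_spec, self.runtime_configuration
        )
    else:
        # fetch the whole chunk, then decode it in full
        chunk_bytes = await store_path.get()
        if chunk_bytes is not None:
            chunk_array = await self.codec_pipeline.decode(
                chunk_bytes, chunk_spec, self.runtime_configuration
            )
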
@@ -316,6 +324,7 @@ async def _write_chunk(
         chunk_selection: SliceSelection,
         out_selection: SliceSelection,
     ):
+        chunk_spec = self.metadata.get_chunk_spec(chunk_coords)
         chunk_key_encoding = self.metadata.chunk_key_encoding
         chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
         store_path = self.store_path / chunk_key
@@ -330,17 +339,16 @@ async def _write_chunk(
                 chunk_array.fill(value)
             else:
                 chunk_array = value[out_selection]
-            await self._write_chunk_to_store(store_path, chunk_array)
+            await self._write_chunk_to_store(store_path, chunk_array, chunk_spec)
 
-        elif len(self.codec_pipeline.codecs) == 1 and isinstance(
-            self.codec_pipeline.codecs[0], ShardingCodec
-        ):
-            sharding_codec = self.codec_pipeline.codecs[0]
+        elif self.codec_pipeline.supports_partial_encode:
             # print("encode_partial", chunk_coords, chunk_selection, repr(self))
-            await sharding_codec.encode_partial(
+            await self.codec_pipeline.encode_partial(
                 store_path,
                 value[out_selection],
                 chunk_selection,
+                chunk_spec,
+                self.runtime_configuration,
             )
         else:
             # writing partial chunks
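
The deleted branch special-cased a pipeline consisting of exactly one `ShardingCodec`; that knowledge now lives behind the `supports_partial_encode` property (mirroring `supports_partial_decode` on the read path), which is also why the `ShardingCodec` import could be dropped above. A hypothetical implementation of the flag, inferred from the deleted check rather than shown in this diff:

    # Hypothetical sketch, mirroring the deleted isinstance check; the real
    # property lives wherever CodecPipeline is defined.
    @property
    def supports_partial_encode(self) -> bool:
        return len(self.codecs) == 1 and isinstance(self.codecs[0], ShardingCodec)
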
@@ -356,18 +364,24 @@ async def _write_chunk(
                 chunk_array.fill(self.metadata.fill_value)
             else:
                 chunk_array = (
-                    await self.codec_pipeline.decode(chunk_bytes)
+                    await self.codec_pipeline.decode(
+                        chunk_bytes, chunk_spec, self.runtime_configuration
+                    )
                 ).copy()  # make a writable copy
             chunk_array[chunk_selection] = value[out_selection]
 
-        await self._write_chunk_to_store(store_path, chunk_array)
+        await self._write_chunk_to_store(store_path, chunk_array, chunk_spec)
 
-    async def _write_chunk_to_store(self, store_path: StorePath, chunk_array: np.ndarray):
+    async def _write_chunk_to_store(
+        self, store_path: StorePath, chunk_array: np.ndarray, chunk_spec: ArraySpec
+    ):
         if np.all(chunk_array == self.metadata.fill_value):
             # chunks that only contain fill_value will be removed
             await store_path.delete()
         else:
-            chunk_bytes = await self.codec_pipeline.encode(chunk_array)
+            chunk_bytes = await self.codec_pipeline.encode(
+                chunk_array, chunk_spec, self.runtime_configuration
+            )
             if chunk_bytes is None:
                 await store_path.delete()
             else:
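
With these changes, `encode` and `decode` have symmetric signatures: `(data, chunk_spec, runtime_configuration)`. A round-trip sketch of the new pipeline API (the call signatures are taken from this diff; the variable names are illustrative):

    chunk_spec = array.metadata.get_chunk_spec(chunk_coords)
    chunk_bytes = await pipeline.encode(chunk_array, chunk_spec, runtime_configuration)
    # encode may return None; the write path above treats that as
    # "delete the stored chunk" rather than writing bytes.
    if chunk_bytes is not None:
        restored = await pipeline.decode(chunk_bytes, chunk_spec, runtime_configuration)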