20
20
21
21
22
22
# from zarr.array_v2 import ArrayV2
23
+ from zarr .buffer import Buffer , Factory , NDArrayLike , NDBuffer
23
24
from zarr .codecs import BytesCodec
24
25
from zarr .codecs .pipeline import CodecPipeline
25
26
from zarr .common import (
@@ -147,7 +148,7 @@ async def open(
147
148
assert zarr_json_bytes is not None
148
149
return cls .from_dict (
149
150
store_path ,
150
- json .loads (zarr_json_bytes ),
151
+ json .loads (zarr_json_bytes . to_bytes () ),
151
152
)
152
153
153
154
@classmethod
@@ -160,7 +161,7 @@ async def open_auto(
160
161
if v3_metadata_bytes is not None :
161
162
return cls .from_dict (
162
163
store_path ,
163
- json .loads (v3_metadata_bytes ),
164
+ json .loads (v3_metadata_bytes . to_bytes () ),
164
165
)
165
166
else :
166
167
raise ValueError ("no v2 support yet" )
@@ -186,7 +187,9 @@ def dtype(self) -> np.dtype[Any]:
186
187
def attrs (self ) -> dict [str , Any ]:
187
188
return self .metadata .attributes
188
189
189
- async def getitem (self , selection : Selection ) -> npt .NDArray [Any ]:
190
+ async def getitem (
191
+ self , selection : Selection , * , factory : Factory .Create = NDBuffer .create
192
+ ) -> NDArrayLike :
190
193
assert isinstance (self .metadata .chunk_grid , RegularChunkGrid )
191
194
indexer = BasicIndexer (
192
195
selection ,
@@ -195,10 +198,8 @@ async def getitem(self, selection: Selection) -> npt.NDArray[Any]:
195
198
)
196
199
197
200
# setup output array
198
- out = np .zeros (
199
- indexer .shape ,
200
- dtype = self .metadata .dtype ,
201
- order = self .order ,
201
+ out = factory (
202
+ shape = indexer .shape , dtype = self .metadata .dtype , order = self .order , fill_value = 0
202
203
)
203
204
204
205
# reading chunks and decoding them
@@ -210,21 +211,17 @@ async def getitem(self, selection: Selection) -> npt.NDArray[Any]:
210
211
self ._read_chunk ,
211
212
config .get ("async.concurrency" ),
212
213
)
213
-
214
- if out .shape :
215
- return out
216
- else :
217
- return out [()]
214
+ return out .as_ndarray_like ()
218
215
219
216
async def _save_metadata (self ) -> None :
220
- await (self .store_path / ZARR_JSON ).set (self .metadata .to_bytes ())
217
+ await (self .store_path / ZARR_JSON ).set (Buffer . from_bytes ( self .metadata .to_bytes () ))
221
218
222
219
async def _read_chunk (
223
220
self ,
224
221
chunk_coords : ChunkCoords ,
225
222
chunk_selection : SliceSelection ,
226
223
out_selection : SliceSelection ,
227
- out : npt . NDArray [ Any ] ,
224
+ out : NDBuffer ,
228
225
) -> None :
229
226
chunk_spec = self .metadata .get_chunk_spec (chunk_coords , self .order )
230
227
chunk_key_encoding = self .metadata .chunk_key_encoding
@@ -246,7 +243,12 @@ async def _read_chunk(
246
243
else :
247
244
out [out_selection ] = self .metadata .fill_value
248
245
249
- async def setitem (self , selection : Selection , value : npt .NDArray [Any ]) -> None :
246
+ async def setitem (
247
+ self ,
248
+ selection : Selection ,
249
+ value : NDArrayLike ,
250
+ factory : Factory .NDArrayLike = NDBuffer .from_ndarray_like ,
251
+ ) -> None :
250
252
assert isinstance (self .metadata .chunk_grid , RegularChunkGrid )
251
253
chunk_shape = self .metadata .chunk_grid .chunk_shape
252
254
indexer = BasicIndexer (
@@ -259,15 +261,19 @@ async def setitem(self, selection: Selection, value: npt.NDArray[Any]) -> None:
259
261
260
262
# check value shape
261
263
if np .isscalar (value ):
262
- # setting a scalar value
263
- pass
264
+ value = np .asanyarray (value )
264
265
else :
265
266
if not hasattr (value , "shape" ):
266
267
value = np .asarray (value , self .metadata .dtype )
267
268
assert value .shape == sel_shape
268
269
if value .dtype .name != self .metadata .dtype .name :
269
270
value = value .astype (self .metadata .dtype , order = "A" )
270
271
272
+ # We accept any ndarray like object from the user and convert it
273
+ # to a NDBuffer (or subclass). From this point onwards, we only pass
274
+ # Buffer and NDBuffer between components.
275
+ value = factory (value )
276
+
271
277
# merging with existing data and encoding chunks
272
278
await concurrent_map (
273
279
[
@@ -286,7 +292,7 @@ async def setitem(self, selection: Selection, value: npt.NDArray[Any]) -> None:
286
292
287
293
async def _write_chunk (
288
294
self ,
289
- value : npt . NDArray [ Any ] ,
295
+ value : NDBuffer ,
290
296
chunk_shape : ChunkCoords ,
291
297
chunk_coords : ChunkCoords ,
292
298
chunk_selection : SliceSelection ,
@@ -300,11 +306,9 @@ async def _write_chunk(
300
306
if is_total_slice (chunk_selection , chunk_shape ):
301
307
# write entire chunks
302
308
if np .isscalar (value ):
303
- chunk_array = np .empty (
304
- chunk_shape ,
305
- dtype = self .metadata .dtype ,
309
+ chunk_array = NDBuffer .create (
310
+ shape = chunk_shape , dtype = self .metadata .dtype , fill_value = value
306
311
)
307
- chunk_array .fill (value )
308
312
else :
309
313
chunk_array = value [out_selection ]
310
314
await self ._write_chunk_to_store (store_path , chunk_array , chunk_spec )
@@ -324,11 +328,11 @@ async def _write_chunk(
324
328
325
329
# merge new value
326
330
if chunk_bytes is None :
327
- chunk_array = np . empty (
328
- chunk_shape ,
331
+ chunk_array = NDBuffer . create (
332
+ shape = chunk_shape ,
329
333
dtype = self .metadata .dtype ,
334
+ fill_value = self .metadata .fill_value ,
330
335
)
331
- chunk_array .fill (self .metadata .fill_value )
332
336
else :
333
337
chunk_array = (
334
338
await self .codecs .decode (chunk_bytes , chunk_spec )
@@ -338,9 +342,9 @@ async def _write_chunk(
338
342
await self ._write_chunk_to_store (store_path , chunk_array , chunk_spec )
339
343
340
344
async def _write_chunk_to_store (
341
- self , store_path : StorePath , chunk_array : npt . NDArray [ Any ] , chunk_spec : ArraySpec
345
+ self , store_path : StorePath , chunk_array : NDBuffer , chunk_spec : ArraySpec
342
346
) -> None :
343
- if np . all ( chunk_array == self .metadata .fill_value ):
347
+ if chunk_array . all_equal ( self .metadata .fill_value ):
344
348
# chunks that only contain fill_value will be removed
345
349
await store_path .delete ()
346
350
else :
@@ -379,14 +383,14 @@ async def _delete_key(key: str) -> None:
379
383
)
380
384
381
385
# Write new metadata
382
- await (self .store_path / ZARR_JSON ).set (new_metadata .to_bytes ())
386
+ await (self .store_path / ZARR_JSON ).set (Buffer . from_bytes ( new_metadata .to_bytes () ))
383
387
return replace (self , metadata = new_metadata )
384
388
385
389
async def update_attributes (self , new_attributes : Dict [str , Any ]) -> AsyncArray :
386
390
new_metadata = replace (self .metadata , attributes = new_attributes )
387
391
388
392
# Write new metadata
389
- await (self .store_path / ZARR_JSON ).set (new_metadata .to_bytes ())
393
+ await (self .store_path / ZARR_JSON ).set (Buffer . from_bytes ( new_metadata .to_bytes () ))
390
394
return replace (self , metadata = new_metadata )
391
395
392
396
def __repr__ (self ) -> str :
0 commit comments