 import math
 import operator
 import re
-from collections.abc import MutableMapping
 from functools import reduce
 from typing import Any

 import numpy as np
-from numcodecs.compat import ensure_bytes, ensure_ndarray
+from numcodecs.compat import ensure_bytes

 from zarr._storage.store import _prefix_to_attrs_key, assert_zarr_v3_api_available
 from zarr.attrs import Attributes
 from zarr.storage import (
     _get_hierarchy_metadata,
     _prefix_to_array_key,
+    KVStore,
     getsize,
     listdir,
     normalize_store_arg,
     normalize_shape,
     normalize_storage_path,
     PartialReadBuffer,
+    ensure_ndarray_like
 )

@@ -98,6 +99,12 @@ class Array:

         .. versionadded:: 2.11

+    meta_array : array-like, optional
+        An array instance to use for determining arrays to create and return
+        to users. Uses `numpy.empty(())` by default.
+
+        .. versionadded:: 2.13
+

     Attributes
     ----------
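The effect of `meta_array` is that output buffers are allocated with `np.empty_like(self._meta_array, ...)` (see the `__init__` and `_get_selection` hunks below), so any array type implementing NumPy's `__array_function__` protocol, CuPy being the motivating example, can be handed back to users. A minimal sketch of the dispatch this relies on, independent of zarr; the CuPy variant is an assumption and appears only as a comment:

import numpy as np

meta = np.empty(())  # the default meta_array: results stay plain numpy.ndarray
out = np.empty_like(meta, shape=(2, 3), dtype="f8", order="C")
assert isinstance(out, np.ndarray)

# With e.g. meta = cupy.empty(()) (assumes CuPy is installed), np.empty_like
# dispatches through __array_function__ and returns a cupy.ndarray instead,
# so selections come back as GPU arrays.
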
@@ -129,6 +136,7 @@ class Array:
     vindex
     oindex
     write_empty_chunks
+    meta_array

     Methods
     -------
@@ -163,6 +171,7 @@ def __init__(
         partial_decompress=False,
         write_empty_chunks=True,
         zarr_version=None,
+        meta_array=None,
     ):
         # N.B., expect at this point store is fully initialized with all
         # configuration metadata fully specified and normalized
@@ -191,8 +200,11 @@ def __init__(
         self._is_view = False
         self._partial_decompress = partial_decompress
         self._write_empty_chunks = write_empty_chunks
+        if meta_array is not None:
+            self._meta_array = np.empty_like(meta_array, shape=())
+        else:
+            self._meta_array = np.empty(())
         self._version = zarr_version
-
         if self._version == 3:
             self._data_key_prefix = 'data/root/' + self._key_prefix
             self._data_path = 'data/root/' + self._path
@@ -555,6 +567,13 @@ def write_empty_chunks(self) -> bool:
         """
         return self._write_empty_chunks

+    @property
+    def meta_array(self):
+        """An array-like instance to use for determining arrays to create and return
+        to users.
+        """
+        return self._meta_array
+
     def __eq__(self, other):
         return (
             isinstance(other, Array) and
@@ -929,7 +948,7 @@ def _get_basic_selection_zd(self, selection, out=None, fields=None):

         except KeyError:
             # chunk not initialized
-            chunk = np.zeros((), dtype=self._dtype)
+            chunk = np.zeros_like(self._meta_array, shape=(), dtype=self._dtype)
             if self._fill_value is not None:
                 chunk.fill(self._fill_value)

@@ -1233,7 +1252,8 @@ def _get_selection(self, indexer, out=None, fields=None):

         # setup output array
         if out is None:
-            out = np.empty(out_shape, dtype=out_dtype, order=self._order)
+            out = np.empty_like(self._meta_array, shape=out_shape,
+                                dtype=out_dtype, order=self._order)
         else:
             check_array_shape('out', out, out_shape)

@@ -1607,9 +1627,13 @@ def set_coordinate_selection(self, selection, value, fields=None):
         # setup indexer
         indexer = CoordinateIndexer(selection, self)

-        # handle value - need to flatten
+        # handle value - need an ndarray-like value that can be flattened
         if not is_scalar(value, self._dtype):
-            value = np.asanyarray(value)
+            try:
+                value = ensure_ndarray_like(value)
+            except TypeError:
+                # Handle types like `list` or `tuple`
+                value = np.array(value, like=self._meta_array)
         if hasattr(value, 'shape') and len(value.shape) > 1:
             value = value.reshape(-1)

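The try/except above relies on `ensure_ndarray_like` accepting only objects that expose a buffer or array interface; plain Python containers raise `TypeError` and are instead converted through NumPy's `like=` protocol, which dispatches to the type of `meta_array`. A small sketch of that fallback, assuming NumPy >= 1.20 for the `like=` argument:

import numpy as np
from numcodecs.compat import ensure_ndarray_like

meta = np.empty(())  # stand-in for self._meta_array
try:
    value = ensure_ndarray_like([1, 2, 3])   # plain list: no buffer to wrap
except TypeError:
    value = np.array([1, 2, 3], like=meta)   # like= dispatches on the meta array
assert isinstance(value, np.ndarray)
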
@@ -1712,7 +1736,7 @@ def _set_basic_selection_zd(self, selection, value, fields=None):

         except KeyError:
             # chunk not initialized
-            chunk = np.zeros((), dtype=self._dtype)
+            chunk = np.zeros_like(self._meta_array, shape=(), dtype=self._dtype)
             if self._fill_value is not None:
                 chunk.fill(self._fill_value)

@@ -1772,7 +1796,7 @@ def _set_selection(self, indexer, value, fields=None):
             pass
         else:
             if not hasattr(value, 'shape'):
-                value = np.asanyarray(value)
+                value = np.asanyarray(value, like=self._meta_array)
             check_array_shape('value', value, sel_shape)

         # iterate over chunks in range
@@ -1840,8 +1864,11 @@ def _process_chunk(
                 self._dtype != object):

             dest = out[out_selection]
+            # Assume that array-like objects that don't have a
+            # `writeable` flag are writeable.
+            dest_is_writable = getattr(dest, "writeable", True)
             write_direct = (
-                dest.flags.writeable and
+                dest_is_writable and
                 (
                     (self._order == 'C' and dest.flags.c_contiguous) or
                     (self._order == 'F' and dest.flags.f_contiguous)
@@ -1858,7 +1885,7 @@ def _process_chunk(
                         cdata = cdata.read_full()
                     self._compressor.decode(cdata, dest)
                 else:
-                    chunk = ensure_ndarray(cdata).view(self._dtype)
+                    chunk = ensure_ndarray_like(cdata).view(self._dtype)
                     chunk = chunk.reshape(self._chunks, order=self._order)
                     np.copyto(dest, chunk)
                 return
@@ -1868,7 +1895,7 @@ def _process_chunk(
             if partial_read_decode:
                 cdata.prepare_chunk()
                 # size of chunk
-                tmp = np.empty(self._chunks, dtype=self.dtype)
+                tmp = np.empty_like(self._meta_array, shape=self._chunks, dtype=self.dtype)
                 index_selection = PartialChunkIterator(chunk_selection, self.chunks)
                 for start, nitems, partial_out_selection in index_selection:
                     expected_shape = [
@@ -1925,7 +1952,7 @@ def _chunk_getitem(self, chunk_coords, chunk_selection, out, out_selection,
         """
         out_is_ndarray = True
         try:
-            out = ensure_ndarray(out)
+            out = ensure_ndarray_like(out)
         except TypeError:
             out_is_ndarray = False

@@ -1960,7 +1987,7 @@ def _chunk_getitems(self, lchunk_coords, lchunk_selection, out, lout_selection,
19601987 """
19611988 out_is_ndarray = True
19621989 try :
1963- out = ensure_ndarray (out )
1990+ out = ensure_ndarray_like (out )
19641991 except TypeError : # pragma: no cover
19651992 out_is_ndarray = False
19661993
@@ -2082,7 +2109,9 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):
             if is_scalar(value, self._dtype):

                 # setup array filled with value
-                chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
+                chunk = np.empty_like(
+                    self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                )
                 chunk.fill(value)

             else:
@@ -2102,14 +2131,18 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):

                 # chunk not initialized
                 if self._fill_value is not None:
-                    chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
+                    chunk = np.empty_like(
+                        self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                    )
                     chunk.fill(self._fill_value)
                 elif self._dtype == object:
                     chunk = np.empty(self._chunks, dtype=self._dtype, order=self._order)
                 else:
                     # N.B., use zeros here so any region beyond the array has consistent
                     # and compressible data
-                    chunk = np.zeros(self._chunks, dtype=self._dtype, order=self._order)
+                    chunk = np.zeros_like(
+                        self._meta_array, shape=self._chunks, dtype=self._dtype, order=self._order
+                    )

             else:

@@ -2159,7 +2192,7 @@ def _decode_chunk(self, cdata, start=None, nitems=None, expected_shape=None):
                 chunk = f.decode(chunk)

         # view as numpy array with correct dtype
-        chunk = ensure_ndarray(chunk)
+        chunk = ensure_ndarray_like(chunk)
         # special case object dtype, because incorrect handling can lead to
         # segfaults and other bad things happening
         if self._dtype != object:
@@ -2186,7 +2219,7 @@ def _encode_chunk(self, chunk):
                 chunk = f.encode(chunk)

         # check object encoding
-        if ensure_ndarray(chunk).dtype == object:
+        if ensure_ndarray_like(chunk).dtype == object:
             raise RuntimeError('cannot write object array without object codec')

         # compress
@@ -2196,7 +2229,7 @@ def _encode_chunk(self, chunk):
             cdata = chunk

         # ensure in-memory data is immutable and easy to compare
-        if isinstance(self.chunk_store, MutableMapping):
+        if isinstance(self.chunk_store, KVStore):
             cdata = ensure_bytes(cdata)

         return cdata
@@ -2354,12 +2387,22 @@ def hexdigest(self, hashname="sha1"):
         return checksum

     def __getstate__(self):
-        return (self._store, self._path, self._read_only, self._chunk_store,
-                self._synchronizer, self._cache_metadata, self._attrs.cache,
-                self._partial_decompress, self._write_empty_chunks, self._version)
+        return {
+            "store": self._store,
+            "path": self._path,
+            "read_only": self._read_only,
+            "chunk_store": self._chunk_store,
+            "synchronizer": self._synchronizer,
+            "cache_metadata": self._cache_metadata,
+            "cache_attrs": self._attrs.cache,
+            "partial_decompress": self._partial_decompress,
+            "write_empty_chunks": self._write_empty_chunks,
+            "zarr_version": self._version,
+            "meta_array": self._meta_array,
+        }

     def __setstate__(self, state):
-        self.__init__(*state)
+        self.__init__(**state)

     def _synchronized_op(self, f, *args, **kwargs):

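Because `__getstate__` now returns keyword arguments instead of a positional tuple, `__setstate__` can forward them straight to `__init__`, and the new `meta_array` travels with the pickled state. A minimal round-trip sketch; `zarr.zeros` merely stands in for however the array was created:

import pickle
import zarr

z = zarr.zeros((10,), chunks=(5,), dtype="i4")   # default meta_array: 0-d numpy array
z2 = pickle.loads(pickle.dumps(z))
assert isinstance(z2.meta_array, type(z.meta_array))
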
@@ -2466,7 +2509,7 @@ def append(self, data, axis=0):

         Parameters
         ----------
-        data : array_like
+        data : array-like
             Data to be appended.
         axis : int
             Axis along which to append.
@@ -2502,7 +2545,7 @@ def _append_nosync(self, data, axis=0):

         # ensure data is array-like
         if not hasattr(data, 'shape'):
-            data = np.asanyarray(data)
+            data = np.asanyarray(data, like=self._meta_array)

         # ensure shapes are compatible for non-append dimensions
         self_shape_preserved = tuple(s for i, s in enumerate(self._shape)