@@ -54,6 +54,9 @@ def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides
54
54
55
55
def _create_empty_array (numdims , idims , dtype ):
56
56
out_arr = c_void_ptr_t (0 )
57
+
58
+ if numdims == 0 : return out_arr
59
+
57
60
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
58
61
safe_call (backend .get ().af_create_handle (c_pointer (out_arr ),
59
62
numdims , c_pointer (c_dims ), dtype .value ))
@@ -160,19 +163,18 @@ def _slice_to_length(key, dim):
160
163
161
164
def _get_info (dims , buf_len ):
162
165
elements = 1
163
- numdims = len (dims )
164
- idims = [1 ]* 4
165
-
166
- for i in range (numdims ):
167
- elements *= dims [i ]
168
- idims [i ] = dims [i ]
169
-
170
- if (elements == 0 ):
171
- if (buf_len != 0 ):
172
- idims = [buf_len , 1 , 1 , 1 ]
173
- numdims = 1
174
- else :
175
- raise RuntimeError ("Invalid size" )
166
+ numdims = 0
167
+ if dims :
168
+ numdims = len (dims )
169
+ idims = [1 ]* 4
170
+ for i in range (numdims ):
171
+ elements *= dims [i ]
172
+ idims [i ] = dims [i ]
173
+ elif (buf_len != 0 ):
174
+ idims = [buf_len , 1 , 1 , 1 ]
175
+ numdims = 1
176
+ else :
177
+ raise RuntimeError ("Invalid size" )
176
178
177
179
return numdims , idims
178
180
@@ -382,7 +384,7 @@ class Array(BaseArray):
382
384
# arrayfire's __radd__() instead of numpy's __add__()
383
385
__array_priority__ = 30
384
386
385
- def __init__ (self , src = None , dims = ( 0 ,) , dtype = None , is_device = False , offset = None , strides = None ):
387
+ def __init__ (self , src = None , dims = None , dtype = None , is_device = False , offset = None , strides = None ):
386
388
387
389
super (Array , self ).__init__ ()
388
390
@@ -449,10 +451,12 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False, offset=None
449
451
if type_char is None :
450
452
type_char = 'f'
451
453
452
- numdims = len (dims )
454
+ numdims = len (dims ) if dims else 0
455
+
453
456
idims = [1 ] * 4
454
457
for n in range (numdims ):
455
458
idims [n ] = dims [n ]
459
+
456
460
self .arr = _create_empty_array (numdims , idims , to_dtype [type_char ])
457
461
458
462
def as_type (self , ty ):
0 commit comments