Skip to content

Commit 4a8131c

Browse files
Change to allow as_usm_memory to consume SUAI with zero dimensions
1 parent 262f098 commit 4a8131c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

dpctl/memory/_sycl_usm_array_interface_utils.pxi

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,27 +88,38 @@ cdef object _pointers_from_shape_and_stride(
8888
8989
Returns: tuple(min_disp, nbytes)
9090
"""
91+
cdef Py_ssize_t nelems = 1
92+
cdef Py_ssize_t min_disp = 0
93+
cdef Py_ssize_t max_disp = 0
94+
cdef int i
95+
cdef Py_ssize_t sh_i = 0
96+
cdef Py_ssize_t str_i = 0
9197
if (nd > 0):
9298
if (ary_strides is None):
9399
nelems = 1
94100
for si in ary_shape:
95101
sh_i = int(si)
96-
if (sh_i <= 0):
102+
if (sh_i < 0):
97103
raise ValueError("Array shape elements need to be positive")
98104
nelems = nelems * sh_i
99-
return (ary_offset, nelems * itemsize)
105+
return (ary_offset, max(nelems, 1) * itemsize)
100106
else:
101107
min_disp = ary_offset
102108
max_disp = ary_offset
103109
for i in range(nd):
104110
str_i = int(ary_strides[i])
105111
sh_i = int(ary_shape[i])
106-
if (sh_i <= 0):
112+
if (sh_i < 0):
107113
raise ValueError("Array shape elements need to be positive")
108-
if (str_i > 0):
109-
max_disp += str_i * (sh_i - 1)
114+
if (sh_i > 0):
115+
if (str_i > 0):
116+
max_disp += str_i * (sh_i - 1)
117+
else:
118+
min_disp += str_i * (sh_i - 1)
110119
else:
111-
min_disp += str_i * (sh_i - 1);
120+
nelems = 0
121+
if nelems == 0:
122+
return (ary_offset, itemsize)
112123
return (min_disp, (max_disp - min_disp + 1) * itemsize)
113124
elif (nd == 0):
114125
return (ary_offset, itemsize)

0 commit comments

Comments
 (0)