Skip to content

Commit 369eda4

Browse files
Fixes gh-1331
``` In [1]: import dpctl.tensor as dpt In [2]: a = dpt.ones((2, 3)) ...: dpt.reshape(a, (1, 6, 1)).flags Out[2]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [3]: a = dpt.ones((2, 3), order='F') ...: dpt.reshape(a, (1, 6, 1), order='F').flags Out[3]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [4]: a = dpt.ones((2, 3, 4)) ...: dpt.sum(a, axis=(1, 2), keepdims=True).flags Out[4]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True ```
1 parent 346200e commit 369eda4

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

dpctl/tensor/_stride_utils.pxi

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ cdef int _from_input_shape_strides(
6464
cdef int j
6565
cdef bint all_incr = 1
6666
cdef bint all_decr = 1
67-
cdef bint all_incr_modified = 0
68-
cdef bint all_decr_modified = 0
67+
cdef bint strides_inspected = 0
6968
cdef Py_ssize_t elem_count = 1
7069
cdef Py_ssize_t min_shift = 0
7170
cdef Py_ssize_t max_shift = 0
@@ -167,27 +166,27 @@ cdef int _from_input_shape_strides(
167166
while (j < nd and shape_arr[j] == 1):
168167
j = j + 1
169168
if j < nd:
169+
strides_inspected = 1
170170
if all_incr:
171-
all_incr_modified = 1
172171
all_incr = (
173172
(strides_arr[i] > 0) and
174173
(strides_arr[j] > 0) and
175174
(strides_arr[i] <= strides_arr[j])
176175
)
177176
if all_decr:
178-
all_decr_modified = 1
179177
all_decr = (
180178
(strides_arr[i] > 0) and
181179
(strides_arr[j] > 0) and
182180
(strides_arr[i] >= strides_arr[j])
183181
)
184182
i = j
185183
else:
184+
strides_inspected = 1
186185
break
187186
# should only set contig flags on actually obtained
188187
# values, rather than default values
189-
all_incr = all_incr and all_incr_modified
190-
all_decr = all_decr and all_decr_modified
188+
all_incr = all_incr and strides_inspected
189+
all_decr = all_decr and strides_inspected
191190
if all_incr and all_decr:
192191
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
193192
elif all_incr:

0 commit comments

Comments
 (0)