Skip to content

Commit f5b96b7

Browse files
Fixes gh-1334
A non-empty array which is effectively 1D (only one dimension has size greater than one) should be marked as both C- and F- contiguous. ``` In [1]: import dpctl.tensor as dpt In [2]: a = dpt.ones((2, 3)) ...: dpt.reshape(a, (1, 6, 1)).flags Out[2]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [3]: a = dpt.ones((2, 3), order='F') ...: dpt.reshape(a, (1, 6, 1), order='F').flags Out[3]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True In [4]: a = dpt.ones((2, 3, 4)) ...: dpt.sum(a, axis=(1, 2), keepdims=True).flags Out[4]: C_CONTIGUOUS : True F_CONTIGUOUS : True WRITABLE : True ```
1 parent 346200e commit f5b96b7

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

dpctl/tensor/_stride_utils.pxi

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ cdef int _from_input_shape_strides(
6464
cdef int j
6565
cdef bint all_incr = 1
6666
cdef bint all_decr = 1
67-
cdef bint all_incr_modified = 0
68-
cdef bint all_decr_modified = 0
67+
cdef bint strides_inspected = 0
6968
cdef Py_ssize_t elem_count = 1
7069
cdef Py_ssize_t min_shift = 0
7170
cdef Py_ssize_t max_shift = 0
@@ -167,27 +166,33 @@ cdef int _from_input_shape_strides(
167166
while (j < nd and shape_arr[j] == 1):
168167
j = j + 1
169168
if j < nd:
169+
strides_inspected = 1
170170
if all_incr:
171-
all_incr_modified = 1
172171
all_incr = (
173172
(strides_arr[i] > 0) and
174173
(strides_arr[j] > 0) and
175174
(strides_arr[i] <= strides_arr[j])
176175
)
177176
if all_decr:
178-
all_decr_modified = 1
179177
all_decr = (
180178
(strides_arr[i] > 0) and
181179
(strides_arr[j] > 0) and
182180
(strides_arr[i] >= strides_arr[j])
183181
)
184182
i = j
185183
else:
184+
if not strides_inspected:
185+
# all dimensions have size 1 except
186+
# dimension 'i'. Array is both C and F
187+
# contiguous
188+
strides_inspected = 1
189+
all_incr = (strides_arr[i] == 1)
190+
all_decr = all_incr
186191
break
187192
# should only set contig flags on actually obtained
188193
# values, rather than default values
189-
all_incr = all_incr and all_incr_modified
190-
all_decr = all_decr and all_decr_modified
194+
all_incr = all_incr and strides_inspected
195+
all_decr = all_decr and strides_inspected
191196
if all_incr and all_decr:
192197
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
193198
elif all_incr:

0 commit comments

Comments
 (0)