Skip to content

Commit ddafefe

Browse files
Fixes #693
Constructor make more rigorous checks for contiguity flags. When constructor is checking contiguity of the 1D array, it needs to account for 1-element array no matter the value of strides[0]. Similarly for higher dimensionality, contiguity check should not take into account slots with dimensions of 1. X1 = usm_ndarray((1,3,1), 'i4', 'device', strides=(0,1,0)) X2 = usm_ndarray((2,1,3), 'i4', 'device', strides=(3, 0, 1)) X3 = usm_ndarray((2,1,3), 'i4', 'device', strides=(1, 0, 2))
1 parent 29c2cbc commit ddafefe

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

dpctl/tensor/_stride_utils.pxi

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ cdef int _from_input_shape_strides(
6161
Otherwise they are set to NULL
6262
"""
6363
cdef int i
64+
cdef int j
6465
cdef int all_incr = 1
6566
cdef int all_decr = 1
6667
cdef Py_ssize_t elem_count = 1
@@ -115,6 +116,15 @@ cdef int _from_input_shape_strides(
115116
contig[0] = USM_ARRAY_C_CONTIGUOUS
116117
else:
117118
contig[0] = USM_ARRAY_F_CONTIGUOUS
119+
if nd == 1:
120+
contig[0] = USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS
121+
else:
122+
j = 0
123+
for i in range(nd):
124+
if shape_arr[i] > 1:
125+
j = j + 1
126+
if j < 2:
127+
contig[0] = USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS
118128
min_disp[0] = 0
119129
max_disp[0] = (elem_count - 1)
120130
strides_ptr[0] = <Py_ssize_t *>(<size_t>0)
@@ -137,26 +147,42 @@ cdef int _from_input_shape_strides(
137147
min_disp[0] = min_shift
138148
max_disp[0] = max_shift
139149
if max_shift == min_shift + (elem_count - 1):
150+
if elem_count == 1:
151+
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
152+
return 0
140153
if nd == 1:
141154
if strides_arr[0] == 1:
142155
contig[0] = USM_ARRAY_C_CONTIGUOUS
143156
else:
144157
contig[0] = 0
145158
return 0
146-
for i in range(0, nd - 1):
147-
if all_incr:
148-
all_incr = (
149-
(strides_arr[i] > 0) and
150-
(strides_arr[i+1] > 0) and
151-
(strides_arr[i] <= strides_arr[i + 1])
152-
)
153-
if all_decr:
154-
all_decr = (
155-
(strides_arr[i] > 0) and
156-
(strides_arr[i+1] > 0) and
157-
(strides_arr[i] >= strides_arr[i + 1])
158-
)
159-
if all_incr:
159+
i = 0
160+
while i < nd:
161+
if shape_arr[i] == 1:
162+
i = i + 1
163+
continue
164+
j = i + 1
165+
while (j < nd and shape_arr[j] == 1):
166+
j = j + 1
167+
if j < nd:
168+
if all_incr:
169+
all_incr = (
170+
(strides_arr[i] > 0) and
171+
(strides_arr[j] > 0) and
172+
(strides_arr[i] <= strides_arr[j])
173+
)
174+
if all_decr:
175+
all_decr = (
176+
(strides_arr[i] > 0) and
177+
(strides_arr[j] > 0) and
178+
(strides_arr[i] >= strides_arr[j])
179+
)
180+
i = j
181+
else:
182+
break
183+
if all_incr and all_decr:
184+
contig[0] = (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)
185+
elif all_incr:
160186
contig[0] = USM_ARRAY_F_CONTIGUOUS
161187
elif all_decr:
162188
contig[0] = USM_ARRAY_C_CONTIGUOUS

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ def test_allocate_usm_ndarray(shape, usm_type):
5959
assert X.shape == X.__sycl_usm_array_interface__["shape"]
6060

6161

62+
def test_usm_ndarray_flags():
63+
assert dpt.usm_ndarray((5,)).flags == 3
64+
assert dpt.usm_ndarray((5, 2)).flags == 1
65+
assert dpt.usm_ndarray((5, 2), order="F").flags == 2
66+
assert dpt.usm_ndarray((5, 1, 2), order="F").flags == 2
67+
assert dpt.usm_ndarray((5, 1, 2), strides=(2, 0, 1)).flags == 1
68+
assert dpt.usm_ndarray((5, 1, 2), strides=(1, 0, 5)).flags == 2
69+
assert dpt.usm_ndarray((5, 1, 1), strides=(1, 0, 1)).flags == 3
70+
71+
6272
@pytest.mark.parametrize(
6373
"dtype",
6474
[
@@ -703,11 +713,10 @@ def relaxed_strides_equal(st1, st2, sh):
703713
5,
704714
)
705715
X = dpt.usm_ndarray(sh_s, dtype="d")
706-
expected_flags = X.flags
707716
X.shape = sh_f
708717
assert X.shape == sh_f
709718
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
710-
assert X.flags == expected_flags
719+
assert X.flags & 1, "reshaped array expected to be C-contiguous"
711720

712721
sh_s = (
713722
2,
@@ -842,6 +851,10 @@ def test_reshape():
842851
W = dpt.reshape(Z, (-1,), order="C")
843852
assert W.shape == (Z.size,)
844853

854+
X = dpt.usm_ndarray((1,))
855+
Y = dpt.reshape(X, X.shape)
856+
assert Y.flags == X.flags
857+
845858

846859
def test_transpose():
847860
n, m = 2, 3

0 commit comments

Comments
 (0)