Skip to content

Commit 1320d39

Browse files
Merge pull request #1202 from IntelPython/fix-gh-1201
Fixed regression in iteration space simplifier, fixed arguments of numpy to usm_ndarray copying function
2 parents 004e3d9 + 8e43c77 commit 1320d39

File tree

3 files changed

+72
-56
lines changed

3 files changed

+72
-56
lines changed

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
211211
}
212212

213213
// Minumum and maximum element offsets for source np.ndarray
214-
py::ssize_t npy_src_min_nelem_offset(0);
215-
py::ssize_t npy_src_max_nelem_offset(0);
214+
py::ssize_t npy_src_min_nelem_offset(src_offset);
215+
py::ssize_t npy_src_max_nelem_offset(src_offset);
216216
for (int i = 0; i < nd; ++i) {
217217
if (simplified_src_strides[i] < 0) {
218218
npy_src_min_nelem_offset +=

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,22 @@ void simplify_iteration_space(int &nd,
141141
assert(simplified_shape.size() == static_cast<size_t>(nd));
142142

143143
simplified_src_strides.reserve(nd);
144-
simplified_src_strides.push_back(
145-
(src_strides[0] >= 0) ? src_strides[0] : -src_strides[0]);
146-
if ((src_strides[0] < 0) && (shape[0] > 1)) {
147-
src_offset += (shape[0] - 1) * src_strides[0];
148-
}
149-
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
150-
151144
simplified_dst_strides.reserve(nd);
152-
simplified_dst_strides.push_back(
153-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
154-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
155-
dst_offset += (shape[0] - 1) * dst_strides[0];
145+
146+
if (src_strides[0] < 0 && dst_strides[0] < 0) {
147+
simplified_src_strides.push_back(-src_strides[0]);
148+
simplified_dst_strides.push_back(-dst_strides[0]);
149+
if (shape[0] > 1) {
150+
src_offset += (shape[0] - 1) * src_strides[0];
151+
dst_offset += (shape[0] - 1) * dst_strides[0];
152+
}
153+
}
154+
else {
155+
simplified_src_strides.push_back(src_strides[0]);
156+
simplified_dst_strides.push_back(dst_strides[0]);
156157
}
158+
159+
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
157160
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
158161
}
159162
}
@@ -226,27 +229,28 @@ void simplify_iteration_space_3(
226229
assert(simplified_shape.size() == static_cast<size_t>(nd));
227230

228231
simplified_src1_strides.reserve(nd);
229-
simplified_src1_strides.push_back(
230-
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
231-
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
232-
src1_offset += src1_strides[0] * (shape[0] - 1);
233-
}
234-
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
235-
236232
simplified_src2_strides.reserve(nd);
237-
simplified_src2_strides.push_back(
238-
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
239-
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
240-
src2_offset += src2_strides[0] * (shape[0] - 1);
241-
}
242-
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
243-
244233
simplified_dst_strides.reserve(nd);
245-
simplified_dst_strides.push_back(
246-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
247-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
248-
dst_offset += dst_strides[0] * (shape[0] - 1);
234+
235+
if ((src1_strides[0] < 0) && (src2_strides[0] < 0) &&
236+
(dst_strides[0] < 0)) {
237+
simplified_src1_strides.push_back(-src1_strides[0]);
238+
simplified_src2_strides.push_back(-src2_strides[0]);
239+
simplified_dst_strides.push_back(-dst_strides[0]);
240+
if (shape[0] > 1) {
241+
src1_offset += src1_strides[0] * (shape[0] - 1);
242+
src2_offset += src2_strides[0] * (shape[0] - 1);
243+
dst_offset += dst_strides[0] * (shape[0] - 1);
244+
}
245+
}
246+
else {
247+
simplified_src1_strides.push_back(src1_strides[0]);
248+
simplified_src2_strides.push_back(src2_strides[0]);
249+
simplified_dst_strides.push_back(dst_strides[0]);
249250
}
251+
252+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
253+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
250254
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
251255
}
252256
}
@@ -333,35 +337,34 @@ void simplify_iteration_space_4(
333337
assert(simplified_shape.size() == static_cast<size_t>(nd));
334338

335339
simplified_src1_strides.reserve(nd);
336-
simplified_src1_strides.push_back(
337-
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
338-
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
339-
src1_offset += src1_strides[0] * (shape[0] - 1);
340-
}
341-
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
342-
343340
simplified_src2_strides.reserve(nd);
344-
simplified_src2_strides.push_back(
345-
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
346-
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
347-
src2_offset += src2_strides[0] * (shape[0] - 1);
348-
}
349-
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
350-
351341
simplified_src3_strides.reserve(nd);
352-
simplified_src3_strides.push_back(
353-
(src3_strides[0] >= 0) ? src3_strides[0] : -src3_strides[0]);
354-
if ((src3_strides[0] < 0) && (shape[0] > 1)) {
355-
src3_offset += src3_strides[0] * (shape[0] - 1);
356-
}
357-
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
358-
359342
simplified_dst_strides.reserve(nd);
360-
simplified_dst_strides.push_back(
361-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
362-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
363-
dst_offset += dst_strides[0] * (shape[0] - 1);
343+
344+
if ((src1_strides[0] < 0) && (src2_strides[0] < 0) &&
345+
(src3_strides[0] < 0) && (dst_strides[0] < 0))
346+
{
347+
simplified_src1_strides.push_back(-src1_strides[0]);
348+
simplified_src2_strides.push_back(-src2_strides[0]);
349+
simplified_src3_strides.push_back(-src3_strides[0]);
350+
simplified_dst_strides.push_back(-dst_strides[0]);
351+
if (shape[0] > 1) {
352+
src1_offset += src1_strides[0] * (shape[0] - 1);
353+
src2_offset += src2_strides[0] * (shape[0] - 1);
354+
src3_offset += src3_strides[0] * (shape[0] - 1);
355+
dst_offset += dst_strides[0] * (shape[0] - 1);
356+
}
357+
}
358+
else {
359+
simplified_src1_strides.push_back(src1_strides[0]);
360+
simplified_src2_strides.push_back(src2_strides[0]);
361+
simplified_src3_strides.push_back(src3_strides[0]);
362+
simplified_dst_strides.push_back(dst_strides[0]);
364363
}
364+
365+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
366+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
367+
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
365368
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
366369
}
367370
}

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2059,3 +2059,16 @@ def test_byte_bounds():
20592059
y = x[::-1, ::2]
20602060
lo, hi = y._byte_bounds
20612061
assert hi - lo == (n0 * n1 - 1) * x.itemsize
2062+
2063+
2064+
def test_gh_1201():
2065+
n = 100
2066+
a = np.flipud(np.arange(n, dtype="i4"))
2067+
try:
2068+
b = dpt.asarray(a)
2069+
except dpctl.SyclDeviceCreationError:
2070+
pytest.skip("No SYCL devices available")
2071+
assert (dpt.asnumpy(b) == a).all()
2072+
c = dpt.flip(dpt.empty(a.shape, dtype=a.dtype))
2073+
c[:] = a
2074+
assert (dpt.asnumpy(c) == a).all()

0 commit comments

Comments
 (0)