Skip to content

Commit eb97326

Browse files
Simplify_iteration_space was doing it incorrectly since gh-1198
For nd==1, and when simplifying iteration over multiple arrays, strides should only change sign if all strides are negative.
1 parent 004e3d9 commit eb97326

File tree

1 file changed

+57
-54
lines changed

1 file changed

+57
-54
lines changed

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,22 @@ void simplify_iteration_space(int &nd,
141141
assert(simplified_shape.size() == static_cast<size_t>(nd));
142142

143143
simplified_src_strides.reserve(nd);
144-
simplified_src_strides.push_back(
145-
(src_strides[0] >= 0) ? src_strides[0] : -src_strides[0]);
146-
if ((src_strides[0] < 0) && (shape[0] > 1)) {
147-
src_offset += (shape[0] - 1) * src_strides[0];
148-
}
149-
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
150-
151144
simplified_dst_strides.reserve(nd);
152-
simplified_dst_strides.push_back(
153-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
154-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
155-
dst_offset += (shape[0] - 1) * dst_strides[0];
145+
146+
if (src_strides[0] < 0 && dst_strides[0] < 0) {
147+
simplified_src_strides.push_back(-src_strides[0]);
148+
simplified_dst_strides.push_back(-dst_strides[0]);
149+
if (shape[0] > 1) {
150+
src_offset += (shape[0] - 1) * src_strides[0];
151+
dst_offset += (shape[0] - 1) * dst_strides[0];
152+
}
153+
}
154+
else {
155+
simplified_src_strides.push_back(src_strides[0]);
156+
simplified_dst_strides.push_back(dst_strides[0]);
156157
}
158+
159+
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
157160
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
158161
}
159162
}
@@ -226,27 +229,28 @@ void simplify_iteration_space_3(
226229
assert(simplified_shape.size() == static_cast<size_t>(nd));
227230

228231
simplified_src1_strides.reserve(nd);
229-
simplified_src1_strides.push_back(
230-
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
231-
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
232-
src1_offset += src1_strides[0] * (shape[0] - 1);
233-
}
234-
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
235-
236232
simplified_src2_strides.reserve(nd);
237-
simplified_src2_strides.push_back(
238-
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
239-
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
240-
src2_offset += src2_strides[0] * (shape[0] - 1);
241-
}
242-
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
243-
244233
simplified_dst_strides.reserve(nd);
245-
simplified_dst_strides.push_back(
246-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
247-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
248-
dst_offset += dst_strides[0] * (shape[0] - 1);
234+
235+
if ((src1_strides[0] < 0) && (src2_strides[0] < 0) &&
236+
(dst_strides[0] < 0)) {
237+
simplified_src1_strides.push_back(-src1_strides[0]);
238+
simplified_src2_strides.push_back(-src2_strides[0]);
239+
simplified_dst_strides.push_back(-dst_strides[0]);
240+
if (shape[0] > 1) {
241+
src1_offset += src1_strides[0] * (shape[0] - 1);
242+
src2_offset += src2_strides[0] * (shape[0] - 1);
243+
dst_offset += dst_strides[0] * (shape[0] - 1);
244+
}
245+
}
246+
else {
247+
simplified_src1_strides.push_back(src1_strides[0]);
248+
simplified_src2_strides.push_back(src2_strides[0]);
249+
simplified_dst_strides.push_back(dst_strides[0]);
249250
}
251+
252+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
253+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
250254
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
251255
}
252256
}
@@ -333,35 +337,34 @@ void simplify_iteration_space_4(
333337
assert(simplified_shape.size() == static_cast<size_t>(nd));
334338

335339
simplified_src1_strides.reserve(nd);
336-
simplified_src1_strides.push_back(
337-
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
338-
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
339-
src1_offset += src1_strides[0] * (shape[0] - 1);
340-
}
341-
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
342-
343340
simplified_src2_strides.reserve(nd);
344-
simplified_src2_strides.push_back(
345-
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
346-
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
347-
src2_offset += src2_strides[0] * (shape[0] - 1);
348-
}
349-
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
350-
351341
simplified_src3_strides.reserve(nd);
352-
simplified_src3_strides.push_back(
353-
(src3_strides[0] >= 0) ? src3_strides[0] : -src3_strides[0]);
354-
if ((src3_strides[0] < 0) && (shape[0] > 1)) {
355-
src3_offset += src3_strides[0] * (shape[0] - 1);
356-
}
357-
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
358-
359342
simplified_dst_strides.reserve(nd);
360-
simplified_dst_strides.push_back(
361-
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
362-
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
363-
dst_offset += dst_strides[0] * (shape[0] - 1);
343+
344+
if ((src1_strides[0] < 0) && (src2_strides[0] < 0) &&
345+
(src3_strides[0] < 0) && (dst_strides[0] < 0))
346+
{
347+
simplified_src1_strides.push_back(-src1_strides[0]);
348+
simplified_src2_strides.push_back(-src2_strides[0]);
349+
simplified_src3_strides.push_back(-src3_strides[0]);
350+
simplified_dst_strides.push_back(-dst_strides[0]);
351+
if (shape[0] > 1) {
352+
src1_offset += src1_strides[0] * (shape[0] - 1);
353+
src2_offset += src2_strides[0] * (shape[0] - 1);
354+
src3_offset += src3_strides[0] * (shape[0] - 1);
355+
dst_offset += dst_strides[0] * (shape[0] - 1);
356+
}
357+
}
358+
else {
359+
simplified_src1_strides.push_back(src1_strides[0]);
360+
simplified_src2_strides.push_back(src2_strides[0]);
361+
simplified_src3_strides.push_back(src3_strides[0]);
362+
simplified_dst_strides.push_back(dst_strides[0]);
364363
}
364+
365+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
366+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
367+
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
365368
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
366369
}
367370
}

0 commit comments

Comments
 (0)