Skip to content

Commit 1fbe04e

Browse files
Changed simplify_iteration_two_strides, simplify_iteration_three_strides
The target ordering used to be based on absolute values of the first vector of strides, now it uses lexicographic ordering of tuples of absolute values of all strides involved. This enables iteration space reduction for examples where all strides in the first vector are all zero, like in the example arising from ``` dpctl.tensor.full((2,3,4,), dpctl.tensor.asarray(1)) ``` The following two invocations show that iteration space used is 1d: ``` onetrace -d -v --demangle python -c "import dpctl.tensor as dpt; x = dpt.ones((30, 40, 50), dtype='i4'); y = dpt.empty_like(x, dtype='f4'); print((x.flags, y.flags)); y[:] = x" onetrace -d -v --demangle python -c "import dpctl.tensor._tensor_impl as ti, dpctl.tensor as dpt; dpt.full((2,3,4), dpt.asarray(1, dtype='f4'))" ```
1 parent db8fe2f commit 1fbe04e

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd,
428428
std::iota(pos.begin(), pos.end(), 0);
429429

430430
std::stable_sort(
431-
pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
432-
auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
433-
auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
434-
return (abs_str1 > abs_str2) ||
435-
(abs_str1 == abs_str2 && shape[i1] > shape[i2]);
431+
pos.begin(), pos.end(), [&strides1, &strides2, &shape](int i1, int i2) {
432+
auto abs_str1_i1 =
433+
(strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
434+
auto abs_str1_i2 =
435+
(strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
436+
auto abs_str2_i1 =
437+
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
438+
auto abs_str2_i2 =
439+
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
440+
return (abs_str1_i1 > abs_str1_i2) ||
441+
(abs_str1_i1 == abs_str1_i2 &&
442+
(abs_str2_i1 > abs_str2_i2 ||
443+
(abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2])));
436444
});
437445

438446
std::vector<ShapeTy> shape_w;
@@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd,
458466
strides1_w.push_back(str1_p);
459467
strides2_w.push_back(str2_p);
460468
}
469+
461470
int nd_ = nd;
462471
while (contractable) {
463472
bool changed = false;
@@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd,
570579
std::vector<int> pos(nd);
571580
std::iota(pos.begin(), pos.end(), 0);
572581

573-
std::stable_sort(
574-
pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
575-
auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
576-
auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
577-
return (abs_str1 > abs_str2) ||
578-
(abs_str1 == abs_str2 && shape[i1] > shape[i2]);
579-
});
582+
std::stable_sort(pos.begin(), pos.end(),
583+
[&strides1, &strides2, &strides3, &shape](int i1, int i2) {
584+
auto abs_str1_i1 =
585+
(strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
586+
auto abs_str1_i2 =
587+
(strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
588+
auto abs_str2_i1 =
589+
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
590+
auto abs_str2_i2 =
591+
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
592+
auto abs_str3_i1 =
593+
(strides3[i1] < 0) ? -strides3[i1] : strides3[i1];
594+
auto abs_str3_i2 =
595+
(strides3[i2] < 0) ? -strides3[i2] : strides3[i2];
596+
return (abs_str1_i1 > abs_str1_i2) ||
597+
((abs_str1_i1 == abs_str1_i2) &&
598+
((abs_str2_i1 > abs_str2_i2) ||
599+
((abs_str2_i1 == abs_str2_i2) &&
600+
((abs_str3_i1 > abs_str3_i2) ||
601+
((abs_str3_i1 == abs_str3_i2) &&
602+
(shape[i1] > shape[i2]))))));
603+
});
580604

581605
std::vector<ShapeTy> shape_w;
582606
std::vector<StridesTy> strides1_w;

0 commit comments

Comments
 (0)