Skip to content

Commit 9e37ecb

Browse files
Merge pull request #1054 from IntelPython/enhance-iter-space-simplification
Changed simplify_iteration_two_strides, simplify_iteration_three_strides
2 parents 5271fe1 + 078493e commit 9e37ecb

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 36 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd,
428428
std::iota(pos.begin(), pos.end(), 0);
429429

430430
std::stable_sort(
431-
pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
432-
auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
433-
auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
434-
return (abs_str1 > abs_str2) ||
435-
(abs_str1 == abs_str2 && shape[i1] > shape[i2]);
431+
pos.begin(), pos.end(), [&strides1, &strides2, &shape](int i1, int i2) {
432+
auto abs_str1_i1 =
433+
(strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
434+
auto abs_str1_i2 =
435+
(strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
436+
auto abs_str2_i1 =
437+
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
438+
auto abs_str2_i2 =
439+
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
440+
return (abs_str1_i1 > abs_str1_i2) ||
441+
(abs_str1_i1 == abs_str1_i2 &&
442+
(abs_str2_i1 > abs_str2_i2 ||
443+
(abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2])));
436444
});
437445

438446
std::vector<ShapeTy> shape_w;
@@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd,
458466
strides1_w.push_back(str1_p);
459467
strides2_w.push_back(str2_p);
460468
}
469+
461470
int nd_ = nd;
462471
while (contractable) {
463472
bool changed = false;
@@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd,
570579
std::vector<int> pos(nd);
571580
std::iota(pos.begin(), pos.end(), 0);
572581

573-
std::stable_sort(
574-
pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
575-
auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
576-
auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
577-
return (abs_str1 > abs_str2) ||
578-
(abs_str1 == abs_str2 && shape[i1] > shape[i2]);
579-
});
582+
std::stable_sort(pos.begin(), pos.end(),
583+
[&strides1, &strides2, &strides3, &shape](int i1, int i2) {
584+
auto abs_str1_i1 =
585+
(strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
586+
auto abs_str1_i2 =
587+
(strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
588+
auto abs_str2_i1 =
589+
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
590+
auto abs_str2_i2 =
591+
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
592+
auto abs_str3_i1 =
593+
(strides3[i1] < 0) ? -strides3[i1] : strides3[i1];
594+
auto abs_str3_i2 =
595+
(strides3[i2] < 0) ? -strides3[i2] : strides3[i2];
596+
return (abs_str1_i1 > abs_str1_i2) ||
597+
((abs_str1_i1 == abs_str1_i2) &&
598+
((abs_str2_i1 > abs_str2_i2) ||
599+
((abs_str2_i1 == abs_str2_i2) &&
600+
((abs_str3_i1 > abs_str3_i2) ||
601+
((abs_str3_i1 == abs_str3_i2) &&
602+
(shape[i1] > shape[i2]))))));
603+
});
580604

581605
std::vector<ShapeTy> shape_w;
582606
std::vector<StridesTy> strides1_w;

0 commit comments

Comments
 (0)