@@ -428,11 +428,19 @@ int simplify_iteration_two_strides(const int nd,
428
428
std::iota (pos.begin (), pos.end (), 0 );
429
429
430
430
std::stable_sort (
431
- pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
432
- auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
433
- auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
434
- return (abs_str1 > abs_str2) ||
435
- (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
431
+ pos.begin (), pos.end (), [&strides1, &strides2, &shape](int i1, int i2) {
432
+ auto abs_str1_i1 =
433
+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
434
+ auto abs_str1_i2 =
435
+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
436
+ auto abs_str2_i1 =
437
+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
438
+ auto abs_str2_i2 =
439
+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
440
+ return (abs_str1_i1 > abs_str1_i2) ||
441
+ (abs_str1_i1 == abs_str1_i2 &&
442
+ (abs_str2_i1 > abs_str2_i2 ||
443
+ (abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2])));
436
444
});
437
445
438
446
std::vector<ShapeTy> shape_w;
@@ -458,6 +466,7 @@ int simplify_iteration_two_strides(const int nd,
458
466
strides1_w.push_back (str1_p);
459
467
strides2_w.push_back (str2_p);
460
468
}
469
+
461
470
int nd_ = nd;
462
471
while (contractable) {
463
472
bool changed = false ;
@@ -570,13 +579,28 @@ int simplify_iteration_three_strides(const int nd,
570
579
std::vector<int > pos (nd);
571
580
std::iota (pos.begin (), pos.end (), 0 );
572
581
573
- std::stable_sort (
574
- pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
575
- auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
576
- auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
577
- return (abs_str1 > abs_str2) ||
578
- (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
579
- });
582
+ std::stable_sort (pos.begin (), pos.end (),
583
+ [&strides1, &strides2, &strides3, &shape](int i1, int i2) {
584
+ auto abs_str1_i1 =
585
+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
586
+ auto abs_str1_i2 =
587
+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
588
+ auto abs_str2_i1 =
589
+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
590
+ auto abs_str2_i2 =
591
+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
592
+ auto abs_str3_i1 =
593
+ (strides3[i1] < 0 ) ? -strides3[i1] : strides3[i1];
594
+ auto abs_str3_i2 =
595
+ (strides3[i2] < 0 ) ? -strides3[i2] : strides3[i2];
596
+ return (abs_str1_i1 > abs_str1_i2) ||
597
+ ((abs_str1_i1 == abs_str1_i2) &&
598
+ ((abs_str2_i1 > abs_str2_i2) ||
599
+ ((abs_str2_i1 == abs_str2_i2) &&
600
+ ((abs_str3_i1 > abs_str3_i2) ||
601
+ ((abs_str3_i1 == abs_str3_i2) &&
602
+ (shape[i1] > shape[i2]))))));
603
+ });
580
604
581
605
std::vector<ShapeTy> shape_w;
582
606
std::vector<StridesTy> strides1_w;
0 commit comments