@@ -408,8 +408,8 @@ int simplify_iteration_stride(const int nd,
408
408
409
409
The new shape and new strides, as well as the offset
410
410
`(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
411
- iterating over them will traverse the same pairs of elements, possibly in
412
- different order.
411
+ iterating over them will traverse the same set of pairs of elements,
412
+ possibly in a different order.
413
413
414
414
*/
415
415
template <class ShapeTy , class StridesTy >
@@ -447,7 +447,7 @@ int simplify_iteration_two_strides(const int nd,
447
447
auto str1_p = strides1[p];
448
448
auto str2_p = strides2[p];
449
449
shape_w.push_back (sh_p);
450
- if (str1_p < 0 && str2_p < 0 ) {
450
+ if (str1_p <= 0 && str2_p <= 0 && std::min (str1_p, str2_p) < 0 ) {
451
451
disp1 += str1_p * (sh_p - 1 );
452
452
str1_p = -str1_p;
453
453
disp2 += str2_p * (sh_p - 1 );
@@ -468,7 +468,7 @@ int simplify_iteration_two_strides(const int nd,
468
468
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
469
469
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
470
470
471
- if (jump1 == str1 and jump2 == str2) {
471
+ if (jump1 == str1 && jump2 == str2) {
472
472
changed = true ;
473
473
shape_w[i] *= shape_w[i + 1 ];
474
474
for (int j = i; j < nd_; ++j) {
@@ -540,3 +540,149 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
540
540
out_strides2.resize (nd);
541
541
return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2);
542
542
}
543
+
544
+ /*
545
+ For purposes of iterating over pairs of elements of three arrays
546
+ with `shape` and strides `strides1`, `strides2`, `strides3` given as
547
+ pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
548
+ strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
549
+ may modify memory and returns new length of these arrays.
550
+
551
+ The new shape and new strides, as well as the offset
552
+ `(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3)`
553
+ are such that iterating over them will traverse the same set of tuples of
554
+ elements, possibly in a different order.
555
+
556
+ */
557
+ template <class ShapeTy , class StridesTy >
558
+ int simplify_iteration_three_strides (const int nd,
559
+ ShapeTy *shape,
560
+ StridesTy *strides1,
561
+ StridesTy *strides2,
562
+ StridesTy *strides3,
563
+ StridesTy &disp1,
564
+ StridesTy &disp2,
565
+ StridesTy &disp3)
566
+ {
567
+ disp1 = std::ptrdiff_t (0 );
568
+ disp2 = std::ptrdiff_t (0 );
569
+ if (nd < 2 )
570
+ return nd;
571
+
572
+ std::vector<int > pos (nd);
573
+ std::iota (pos.begin (), pos.end (), 0 );
574
+
575
+ std::stable_sort (
576
+ pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
577
+ auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
578
+ auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
579
+ return (abs_str1 > abs_str2) ||
580
+ (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
581
+ });
582
+
583
+ std::vector<ShapeTy> shape_w;
584
+ std::vector<StridesTy> strides1_w;
585
+ std::vector<StridesTy> strides2_w;
586
+ std::vector<StridesTy> strides3_w;
587
+
588
+ bool contractable = true ;
589
+ for (int i = 0 ; i < nd; ++i) {
590
+ auto p = pos[i];
591
+ auto sh_p = shape[p];
592
+ auto str1_p = strides1[p];
593
+ auto str2_p = strides2[p];
594
+ auto str3_p = strides3[p];
595
+ shape_w.push_back (sh_p);
596
+ if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
597
+ std::min (std::min (str1_p, str2_p), str3_p) < 0 )
598
+ {
599
+ disp1 += str1_p * (sh_p - 1 );
600
+ str1_p = -str1_p;
601
+ disp2 += str2_p * (sh_p - 1 );
602
+ str2_p = -str2_p;
603
+ disp3 += str3_p * (sh_p - 1 );
604
+ str3_p = -str3_p;
605
+ }
606
+ if (str1_p < 0 || str2_p < 0 || str3_p < 0 ) {
607
+ contractable = false ;
608
+ }
609
+ strides1_w.push_back (str1_p);
610
+ strides2_w.push_back (str2_p);
611
+ strides3_w.push_back (str3_p);
612
+ }
613
+ int nd_ = nd;
614
+ while (contractable) {
615
+ bool changed = false ;
616
+ for (int i = 0 ; i + 1 < nd_; ++i) {
617
+ StridesTy str1 = strides1_w[i + 1 ];
618
+ StridesTy str2 = strides2_w[i + 1 ];
619
+ StridesTy str3 = strides3_w[i + 1 ];
620
+ StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
621
+ StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
622
+ StridesTy jump3 = strides3_w[i] - (shape_w[i + 1 ] - 1 ) * str3;
623
+
624
+ if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
625
+ changed = true ;
626
+ shape_w[i] *= shape_w[i + 1 ];
627
+ for (int j = i; j < nd_; ++j) {
628
+ strides1_w[j] = strides1_w[j + 1 ];
629
+ }
630
+ for (int j = i; j < nd_; ++j) {
631
+ strides2_w[j] = strides2_w[j + 1 ];
632
+ }
633
+ for (int j = i; j < nd_; ++j) {
634
+ strides3_w[j] = strides3_w[j + 1 ];
635
+ }
636
+ for (int j = i + 1 ; j + 1 < nd_; ++j) {
637
+ shape_w[j] = shape_w[j + 1 ];
638
+ }
639
+ --nd_;
640
+ break ;
641
+ }
642
+ }
643
+ if (!changed)
644
+ break ;
645
+ }
646
+ for (int i = 0 ; i < nd_; ++i) {
647
+ shape[i] = shape_w[i];
648
+ }
649
+ for (int i = 0 ; i < nd_; ++i) {
650
+ strides1[i] = strides1_w[i];
651
+ }
652
+ for (int i = 0 ; i < nd_; ++i) {
653
+ strides2[i] = strides2_w[i];
654
+ }
655
+ for (int i = 0 ; i < nd_; ++i) {
656
+ strides3[i] = strides3_w[i];
657
+ }
658
+
659
+ return nd_;
660
+ }
661
+
662
+ template <typename T, class Error , typename vecT = std::vector<T>>
663
+ std::tuple<vecT, vecT, T, vecT, T, vecT, T>
664
+ contract_iter3 (vecT shape, vecT strides1, vecT strides2, vecT strides3)
665
+ {
666
+ const size_t dim = shape.size ();
667
+ if (dim != strides1.size () || dim != strides2.size () ||
668
+ dim != strides3.size ()) {
669
+ throw Error (" Shape and strides must be of equal size." );
670
+ }
671
+ vecT out_shape = shape;
672
+ vecT out_strides1 = strides1;
673
+ vecT out_strides2 = strides2;
674
+ vecT out_strides3 = strides3;
675
+ T disp1 (0 );
676
+ T disp2 (0 );
677
+ T disp3 (0 );
678
+
679
+ int nd = simplify_iteration_three_strides (
680
+ dim, out_shape.data (), out_strides1.data (), out_strides2.data (),
681
+ out_strides3.data (), disp1, disp2, disp3);
682
+ out_shape.resize (nd);
683
+ out_strides1.resize (nd);
684
+ out_strides2.resize (nd);
685
+ out_strides3.resize (nd);
686
+ return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2,
687
+ out_strides3, disp3);
688
+ }
0 commit comments