Skip to content

Commit 61e1ea9

Browse files
Added _contract_iter3 utility to simplify iteration space over 3 arrays
``` In [1]: import dpctl.tensor as dpt, dpctl.tensor._tensor_impl as ti, dpctl In [4]: import itertools In [5]: ti._contract_iter2((2, 5, 3), (15, -3, 1), (0,0,1)) Out[5]: ([10, 3], [3, 1], -12, [0, 1], 0) In [6]: or_s = set( (15*i0 - 3*i1 + i2, i2, 15*i0 - 3*i1 + i2) for i0,i1,i2 in itertools.product(range(2), range(5), range(3)) ) In [7]: alt_s = set( (3*i0 + i1 - 12, i1, 3*i0 + i1 - 12) for i0,i1 in itertools.product(range(10), range(3)) ) In [8]: or_s == alt_s Out[8]: True ```
1 parent 1916370 commit 61e1ea9

File tree

2 files changed

+160
-4
lines changed

2 files changed

+160
-4
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,8 @@ int simplify_iteration_stride(const int nd,
408408
409409
The new shape and new strides, as well as the offset
410410
`(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
411-
iterating over them will traverse the same pairs of elements, possibly in
412-
different order.
411+
iterating over them will traverse the same set of pairs of elements,
412+
possibly in a different order.
413413
414414
*/
415415
template <class ShapeTy, class StridesTy>
@@ -447,7 +447,7 @@ int simplify_iteration_two_strides(const int nd,
447447
auto str1_p = strides1[p];
448448
auto str2_p = strides2[p];
449449
shape_w.push_back(sh_p);
450-
if (str1_p < 0 && str2_p < 0) {
450+
if (str1_p <= 0 && str2_p <= 0 && std::min(str1_p, str2_p) < 0) {
451451
disp1 += str1_p * (sh_p - 1);
452452
str1_p = -str1_p;
453453
disp2 += str2_p * (sh_p - 1);
@@ -468,7 +468,7 @@ int simplify_iteration_two_strides(const int nd,
468468
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
469469
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
470470

471-
if (jump1 == str1 and jump2 == str2) {
471+
if (jump1 == str1 && jump2 == str2) {
472472
changed = true;
473473
shape_w[i] *= shape_w[i + 1];
474474
for (int j = i; j < nd_; ++j) {
@@ -540,3 +540,149 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
540540
out_strides2.resize(nd);
541541
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2);
542542
}
543+
544+
/*
545+
For purposes of iterating over triples of elements of three arrays
546+
with `shape` and strides `strides1`, `strides2`, `strides3` given as
547+
pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
548+
strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
549+
may modify memory and returns new length of these arrays.
550+
551+
The new shape and new strides, as well as the offset
552+
`(new_shape, new_strides1, disp1, new_strides2, disp2, new_strides3, disp3)`
553+
are such that iterating over them will traverse the same set of tuples of
554+
elements, possibly in a different order.
555+
556+
*/
557+
/*
   Simplifies, in place, the common iteration space of three strided arrays.

   Parameters:
     nd        number of dimensions (length of `shape` and of every
               strides array).
     shape     extents of the iteration space; overwritten with the
               simplified extents.
     strides1, strides2, strides3
               per-array strides; overwritten with the simplified strides.
     disp1, disp2, disp3
               output displacements. All three are reset to zero on entry
               and then receive the offset accumulated while reversing
               axes, so the caller does not need to initialise them.

   Returns the number of leading entries of `shape`/`strides*` that are
   meaningful after the call. Entries past that count are left untouched.

   For `nd < 2` nothing is rewritten apart from zeroing the displacements.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_three_strides(const int nd,
                                     ShapeTy *shape,
                                     StridesTy *strides1,
                                     StridesTy *strides2,
                                     StridesTy *strides3,
                                     StridesTy &disp1,
                                     StridesTy &disp2,
                                     StridesTy &disp3)
{
    disp1 = std::ptrdiff_t(0);
    disp2 = std::ptrdiff_t(0);
    // disp3 must be reset like the other two; otherwise whatever value the
    // caller passed in leaks into the accumulated offset.
    disp3 = std::ptrdiff_t(0);
    if (nd < 2)
        return nd;

    // Axis permutation: decreasing |strides1|, ties broken by larger extent
    // first. stable_sort keeps the original relative order of full ties.
    std::vector<int> pos(nd);
    std::iota(pos.begin(), pos.end(), 0);

    std::stable_sort(
        pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
            auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
            auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
            return (abs_str1 > abs_str2) ||
                   (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
        });

    std::vector<ShapeTy> shape_w;
    std::vector<StridesTy> strides1_w;
    std::vector<StridesTy> strides2_w;
    std::vector<StridesTy> strides3_w;
    shape_w.reserve(nd);
    strides1_w.reserve(nd);
    strides2_w.reserve(nd);
    strides3_w.reserve(nd);

    bool contractable = true;
    for (int i = 0; i < nd; ++i) {
        auto p = pos[i];
        auto sh_p = shape[p];
        auto str1_p = strides1[p];
        auto str2_p = strides2[p];
        auto str3_p = strides3[p];
        shape_w.push_back(sh_p);
        // An axis whose strides are all non-positive (and at least one is
        // strictly negative) is traversed backwards: move the starting
        // offset to the far end of the axis and negate the strides.
        if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
            std::min(std::min(str1_p, str2_p), str3_p) < 0)
        {
            disp1 += str1_p * (sh_p - 1);
            str1_p = -str1_p;
            disp2 += str2_p * (sh_p - 1);
            str2_p = -str2_p;
            disp3 += str3_p * (sh_p - 1);
            str3_p = -str3_p;
        }
        // A negative stride remaining here means the axis has mixed signs
        // across the arrays; merging of axes is then not attempted at all.
        if (str1_p < 0 || str2_p < 0 || str3_p < 0) {
            contractable = false;
        }
        strides1_w.push_back(str1_p);
        strides2_w.push_back(str2_p);
        strides3_w.push_back(str3_p);
    }

    int nd_ = nd;
    while (contractable) {
        bool changed = false;
        for (int i = 0; i + 1 < nd_; ++i) {
            StridesTy str1 = strides1_w[i + 1];
            StridesTy str2 = strides2_w[i + 1];
            StridesTy str3 = strides3_w[i + 1];
            StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
            StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
            StridesTy jump3 = strides3_w[i] - (shape_w[i + 1] - 1) * str3;

            // Axes i and i+1 can be fused when, in every array, one step
            // along axis i equals shape[i+1] steps along axis i+1.
            if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
                changed = true;
                // The fused axis keeps the inner strides and gets the
                // product of the two extents. erase() performs the shift
                // without reading one element past the end of the vector.
                shape_w[i] *= shape_w[i + 1];
                shape_w.erase(shape_w.begin() + (i + 1));
                strides1_w.erase(strides1_w.begin() + i);
                strides2_w.erase(strides2_w.begin() + i);
                strides3_w.erase(strides3_w.begin() + i);
                --nd_;
                break; // restart the scan from the first axis
            }
        }
        if (!changed)
            break;
    }

    // Each working vector now holds exactly nd_ entries; write them back.
    std::copy(shape_w.begin(), shape_w.end(), shape);
    std::copy(strides1_w.begin(), strides1_w.end(), strides1);
    std::copy(strides2_w.begin(), strides2_w.end(), strides2);
    std::copy(strides3_w.begin(), strides3_w.end(), strides3);

    return nd_;
}
661+
662+
/*
   Vector-based front end to `simplify_iteration_three_strides`.

   Takes a shape and three stride vectors, which must all have the same
   length; otherwise an `Error` is thrown. The inputs are copied, the
   copies are simplified in place, and each is truncated to the reduced
   number of dimensions.

   Returns the tuple
   `(shape, strides1, disp1, strides2, disp2, strides3, disp3)`.
*/
template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T>
contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
{
    const size_t ndim = shape.size();
    const bool lengths_agree = (strides1.size() == ndim) &&
                               (strides2.size() == ndim) &&
                               (strides3.size() == ndim);
    if (!lengths_agree) {
        throw Error("Shape and strides must be of equal size.");
    }

    // Working copies that the simplification routine rewrites in place.
    vecT reduced_shape(shape);
    vecT reduced_strides1(strides1);
    vecT reduced_strides2(strides2);
    vecT reduced_strides3(strides3);

    T offset1(0);
    T offset2(0);
    T offset3(0);

    const int reduced_nd = simplify_iteration_three_strides(
        ndim, reduced_shape.data(), reduced_strides1.data(),
        reduced_strides2.data(), reduced_strides3.data(), offset1, offset2,
        offset3);

    // Only the first `reduced_nd` entries are meaningful after the call.
    reduced_shape.resize(reduced_nd);
    reduced_strides1.resize(reduced_nd);
    reduced_strides2.resize(reduced_nd);
    reduced_strides3.resize(reduced_nd);

    return std::make_tuple(reduced_shape, reduced_strides1, offset1,
                           reduced_strides2, offset2, reduced_strides3,
                           offset3);
}

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ PYBIND11_MODULE(_tensor_impl, m)
133133
"as the original "
134134
"iterator, possibly in a different order.");
135135

136+
m.def(
137+
"_contract_iter3", &contract_iter3<py::ssize_t, py::value_error>,
138+
"Simplifies iteration over elements of 3-tuple of arrays of given "
139+
"shape "
140+
"with strides stride1, stride2, and stride3. Returns "
141+
"a 7-tuple: shape, stride and offset for the new iterator of possible "
142+
"smaller dimension for each array, which traverses the same elements "
143+
"as the original "
144+
"iterator, possibly in a different order.");
145+
136146
m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
137147
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
138148
"number of elements using underlying 'C'-contiguous order for flat "

0 commit comments

Comments
 (0)