Skip to content

Added _contract_iter3 utility to simplify iteration space over 3 arrays #1044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 149 additions & 5 deletions dpctl/tensor/libtensor/include/utils/strided_iters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,8 @@ int simplify_iteration_stride(const int nd,

The new shape and new strides, as well as the offset
`(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
iterating over them will traverse the same pairs of elements, possibly in
different order.

iterating over them will traverse the same set of pairs of elements,
possibly in a different order.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_two_strides(const int nd,
Expand Down Expand Up @@ -447,7 +446,7 @@ int simplify_iteration_two_strides(const int nd,
auto str1_p = strides1[p];
auto str2_p = strides2[p];
shape_w.push_back(sh_p);
if (str1_p < 0 && str2_p < 0) {
if (str1_p <= 0 && str2_p <= 0 && std::min(str1_p, str2_p) < 0) {
disp1 += str1_p * (sh_p - 1);
str1_p = -str1_p;
disp2 += str2_p * (sh_p - 1);
Expand All @@ -468,7 +467,7 @@ int simplify_iteration_two_strides(const int nd,
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;

if (jump1 == str1 and jump2 == str2) {
if (jump1 == str1 && jump2 == str2) {
changed = true;
shape_w[i] *= shape_w[i + 1];
for (int j = i; j < nd_; ++j) {
Expand Down Expand Up @@ -540,3 +539,148 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
out_strides2.resize(nd);
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2);
}

/*
For purposes of iterating over pairs of elements of three arrays
with `shape` and strides `strides1`, `strides2`, `strides3` given as
pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
may modify memory and returns new length of these arrays.

The new shape and new strides, as well as the offset
`(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3)`
are such that iterating over them will traverse the same set of tuples of
elements, possibly in a different order.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_three_strides(const int nd,
ShapeTy *shape,
StridesTy *strides1,
StridesTy *strides2,
StridesTy *strides3,
StridesTy &disp1,
StridesTy &disp2,
StridesTy &disp3)
{
disp1 = std::ptrdiff_t(0);
disp2 = std::ptrdiff_t(0);
if (nd < 2)
return nd;

std::vector<int> pos(nd);
std::iota(pos.begin(), pos.end(), 0);

std::stable_sort(
pos.begin(), pos.end(), [&strides1, &shape](int i1, int i2) {
auto abs_str1 = (strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
auto abs_str2 = (strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
return (abs_str1 > abs_str2) ||
(abs_str1 == abs_str2 && shape[i1] > shape[i2]);
});

std::vector<ShapeTy> shape_w;
std::vector<StridesTy> strides1_w;
std::vector<StridesTy> strides2_w;
std::vector<StridesTy> strides3_w;

bool contractable = true;
for (int i = 0; i < nd; ++i) {
auto p = pos[i];
auto sh_p = shape[p];
auto str1_p = strides1[p];
auto str2_p = strides2[p];
auto str3_p = strides3[p];
shape_w.push_back(sh_p);
if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
std::min(std::min(str1_p, str2_p), str3_p) < 0)
{
disp1 += str1_p * (sh_p - 1);
str1_p = -str1_p;
disp2 += str2_p * (sh_p - 1);
str2_p = -str2_p;
disp3 += str3_p * (sh_p - 1);
str3_p = -str3_p;
}
if (str1_p < 0 || str2_p < 0 || str3_p < 0) {
contractable = false;
}
strides1_w.push_back(str1_p);
strides2_w.push_back(str2_p);
strides3_w.push_back(str3_p);
}
int nd_ = nd;
while (contractable) {
bool changed = false;
for (int i = 0; i + 1 < nd_; ++i) {
StridesTy str1 = strides1_w[i + 1];
StridesTy str2 = strides2_w[i + 1];
StridesTy str3 = strides3_w[i + 1];
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
StridesTy jump3 = strides3_w[i] - (shape_w[i + 1] - 1) * str3;

if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
changed = true;
shape_w[i] *= shape_w[i + 1];
for (int j = i; j < nd_; ++j) {
strides1_w[j] = strides1_w[j + 1];
}
for (int j = i; j < nd_; ++j) {
strides2_w[j] = strides2_w[j + 1];
}
for (int j = i; j < nd_; ++j) {
strides3_w[j] = strides3_w[j + 1];
}
for (int j = i + 1; j + 1 < nd_; ++j) {
shape_w[j] = shape_w[j + 1];
}
--nd_;
break;
}
}
if (!changed)
break;
}
for (int i = 0; i < nd_; ++i) {
shape[i] = shape_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides1[i] = strides1_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides2[i] = strides2_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides3[i] = strides3_w[i];
}

return nd_;
}

template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T>
contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
{
const size_t dim = shape.size();
if (dim != strides1.size() || dim != strides2.size() ||
dim != strides3.size()) {
throw Error("Shape and strides must be of equal size.");
}
vecT out_shape = shape;
vecT out_strides1 = strides1;
vecT out_strides2 = strides2;
vecT out_strides3 = strides3;
T disp1(0);
T disp2(0);
T disp3(0);

int nd = simplify_iteration_three_strides(
dim, out_shape.data(), out_strides1.data(), out_strides2.data(),
out_strides3.data(), disp1, disp2, disp3);
out_shape.resize(nd);
out_strides1.resize(nd);
out_strides2.resize(nd);
out_strides3.resize(nd);
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2,
out_strides3, disp3);
}
10 changes: 10 additions & 0 deletions dpctl/tensor/libtensor/source/tensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ PYBIND11_MODULE(_tensor_impl, m)
"as the original "
"iterator, possibly in a different order.");

m.def(
"_contract_iter3", &contract_iter3<py::ssize_t, py::value_error>,
"Simplifies iteration over elements of 3-tuple of arrays of given "
"shape "
"with strides stride1, stride2, and stride3. Returns "
"a 7-tuple: shape, stride and offset for the new iterator of possible "
"smaller dimension for each array, which traverses the same elements "
"as the original "
"iterator, possibly in a different order.");

m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
"number of elements using underlying 'C'-contiguous order for flat "
Expand Down