Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dpctl/tensor/libtensor/include/kernels/where.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ class WhereStridedFunctor
bool check =
convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);

dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
: x2_p[offsets.get_third_offset()];
dst_p[offsets.get_fourth_offset()] =
check ? x1_p[offsets.get_second_offset()]
: x2_p[offsets.get_third_offset()];
}
};

Expand All @@ -227,6 +228,7 @@ typedef sycl::event (*where_strided_impl_fn_ptr_t)(
py::ssize_t,
py::ssize_t,
py::ssize_t,
py::ssize_t,
const std::vector<sycl::event> &);

template <typename T, typename condT>
Expand All @@ -241,6 +243,7 @@ sycl::event where_strided_impl(sycl::queue q,
py::ssize_t x1_offset,
py::ssize_t x2_offset,
py::ssize_t cond_offset,
py::ssize_t dst_offset,
const std::vector<sycl::event> &depends)
{
const condT *cond_tp = reinterpret_cast<const condT *>(cond_cp);
Expand All @@ -251,13 +254,13 @@ sycl::event where_strided_impl(sycl::queue q,
sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

ThreeOffsets_StridedIndexer indexer{nd, cond_offset, x1_offset,
x2_offset, shape_strides};
FourOffsets_StridedIndexer indexer{
nd, cond_offset, x1_offset, x2_offset, dst_offset, shape_strides};

cgh.parallel_for<
where_strided_kernel<T, condT, ThreeOffsets_StridedIndexer>>(
where_strided_kernel<T, condT, FourOffsets_StridedIndexer>>(
sycl::range<1>(nelems),
WhereStridedFunctor<T, condT, ThreeOffsets_StridedIndexer>(
WhereStridedFunctor<T, condT, FourOffsets_StridedIndexer>(
cond_tp, x1_tp, x2_tp, dst_tp, indexer));
});

Expand Down
109 changes: 109 additions & 0 deletions dpctl/tensor/libtensor/include/utils/offset_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,115 @@ struct ThreeZeroOffsets_Indexer
}
};

template <typename displacementT> struct FourOffsets
{
FourOffsets()
: first_offset(0), second_offset(0), third_offset(0), fourth_offset(0)
{
}
FourOffsets(const displacementT &first_offset_,
const displacementT &second_offset_,
const displacementT &third_offset_,
const displacementT &fourth_offset_)
: first_offset(first_offset_), second_offset(second_offset_),
third_offset(third_offset_), fourth_offset(fourth_offset_)
{
}

displacementT get_first_offset() const
{
return first_offset;
}
displacementT get_second_offset() const
{
return second_offset;
}
displacementT get_third_offset() const
{
return third_offset;
}
displacementT get_fourth_offset() const
{
return fourth_offset;
}

private:
displacementT first_offset = 0;
displacementT second_offset = 0;
displacementT third_offset = 0;
displacementT fourth_offset = 0;
};

struct FourOffsets_StridedIndexer
{
FourOffsets_StridedIndexer(int common_nd,
py::ssize_t first_offset_,
py::ssize_t second_offset_,
py::ssize_t third_offset_,
py::ssize_t fourth_offset_,
py::ssize_t const *_packed_shape_strides)
: nd(common_nd), starting_first_offset(first_offset_),
starting_second_offset(second_offset_),
starting_third_offset(third_offset_),
starting_fourth_offset(fourth_offset_),
shape_strides(_packed_shape_strides)
{
}

FourOffsets<py::ssize_t> operator()(py::ssize_t gid) const
{
return compute_offsets(gid);
}

FourOffsets<py::ssize_t> operator()(size_t gid) const
{
return compute_offsets(static_cast<py::ssize_t>(gid));
}

private:
int nd;
py::ssize_t starting_first_offset;
py::ssize_t starting_second_offset;
py::ssize_t starting_third_offset;
py::ssize_t starting_fourth_offset;
py::ssize_t const *shape_strides;

FourOffsets<py::ssize_t> compute_offsets(py::ssize_t gid) const
{
using dpctl::tensor::strides::CIndexer_vector;

CIndexer_vector _ind(nd);
py::ssize_t relative_first_offset(0);
py::ssize_t relative_second_offset(0);
py::ssize_t relative_third_offset(0);
py::ssize_t relative_fourth_offset(0);
_ind.get_displacement<const py::ssize_t *, const py::ssize_t *>(
gid,
shape_strides, // shape ptr
shape_strides + nd, // strides ptr
shape_strides + 2 * nd, // strides ptr
shape_strides + 3 * nd, // strides ptr
shape_strides + 4 * nd, // strides ptr
relative_first_offset, relative_second_offset,
relative_third_offset, relative_fourth_offset);
return FourOffsets<py::ssize_t>(
starting_first_offset + relative_first_offset,
starting_second_offset + relative_second_offset,
starting_third_offset + relative_third_offset,
starting_fourth_offset + relative_fourth_offset);
}
};

struct FourZeroOffsets_Indexer
{
FourZeroOffsets_Indexer() {}

FourOffsets<py::ssize_t> operator()(py::ssize_t) const
{
return FourOffsets<py::ssize_t>();
}
};

struct NthStrideOffset
{
NthStrideOffset(int common_nd,
Expand Down
194 changes: 193 additions & 1 deletion dpctl/tensor/libtensor/include/utils/strided_iters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ int simplify_iteration_three_strides(const int nd,
auto str3_p = strides3[p];
shape_w.push_back(sh_p);
if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
std::min(std::min(str1_p, str2_p), str3_p) < 0)
std::min({str1_p, str2_p, str3_p}) < 0)
{
disp1 += str1_p * (sh_p - 1);
str1_p = -str1_p;
Expand Down Expand Up @@ -716,6 +716,198 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
out_strides3, disp3);
}

/*
For purposes of iterating over pairs of elements of four arrays
with `shape` and strides `strides1`, `strides2`, `strides3`,
`strides4` given as pointers `simplify_iteration_four_strides(nd,
shape_ptr, strides1_ptr, strides2_ptr, strides3_ptr, strides4_ptr,
disp1, disp2, disp3, disp4)` may modify memory and returns new
length of these arrays.

The new shape and new strides, as well as the offset
`(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3,
new_stride4, disp4)` are such that iterating over them will traverse the
same set of tuples of elements, possibly in a different order.
*/
template <class ShapeTy, class StridesTy>
int simplify_iteration_four_strides(const int nd,
ShapeTy *shape,
StridesTy *strides1,
StridesTy *strides2,
StridesTy *strides3,
StridesTy *strides4,
StridesTy &disp1,
StridesTy &disp2,
StridesTy &disp3,
StridesTy &disp4)
{
disp1 = std::ptrdiff_t(0);
disp2 = std::ptrdiff_t(0);
if (nd < 2)
return nd;

std::vector<int> pos(nd);
std::iota(pos.begin(), pos.end(), 0);

std::stable_sort(
pos.begin(), pos.end(),
[&strides1, &strides2, &strides3, &strides4, &shape](int i1, int i2) {
auto abs_str1_i1 =
(strides1[i1] < 0) ? -strides1[i1] : strides1[i1];
auto abs_str1_i2 =
(strides1[i2] < 0) ? -strides1[i2] : strides1[i2];
auto abs_str2_i1 =
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
auto abs_str2_i2 =
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
auto abs_str3_i1 =
(strides3[i1] < 0) ? -strides3[i1] : strides3[i1];
auto abs_str3_i2 =
(strides3[i2] < 0) ? -strides3[i2] : strides3[i2];
auto abs_str4_i1 =
(strides4[i1] < 0) ? -strides4[i1] : strides4[i1];
auto abs_str4_i2 =
(strides4[i2] < 0) ? -strides4[i2] : strides4[i2];
return (abs_str1_i1 > abs_str1_i2) ||
((abs_str1_i1 == abs_str1_i2) &&
((abs_str2_i1 > abs_str2_i2) ||
((abs_str2_i1 == abs_str2_i2) &&
((abs_str3_i1 > abs_str3_i2) ||
((abs_str3_i1 == abs_str3_i2) &&
((abs_str4_i1 > abs_str4_i2) ||
((abs_str4_i1 == abs_str4_i2) &&
(shape[i1] > shape[i2]))))))));
});

std::vector<ShapeTy> shape_w;
std::vector<StridesTy> strides1_w;
std::vector<StridesTy> strides2_w;
std::vector<StridesTy> strides3_w;
std::vector<StridesTy> strides4_w;

bool contractable = true;
for (int i = 0; i < nd; ++i) {
auto p = pos[i];
auto sh_p = shape[p];
auto str1_p = strides1[p];
auto str2_p = strides2[p];
auto str3_p = strides3[p];
auto str4_p = strides4[p];
shape_w.push_back(sh_p);
if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 && str4_p <= 0 &&
std::min({str1_p, str2_p, str3_p, str4_p}) < 0)
{
disp1 += str1_p * (sh_p - 1);
str1_p = -str1_p;
disp2 += str2_p * (sh_p - 1);
str2_p = -str2_p;
disp3 += str3_p * (sh_p - 1);
str3_p = -str3_p;
disp4 += str4_p * (sh_p - 1);
str4_p = -str4_p;
}
if (str1_p < 0 || str2_p < 0 || str3_p < 0 || str4_p < 0) {
contractable = false;
}
strides1_w.push_back(str1_p);
strides2_w.push_back(str2_p);
strides3_w.push_back(str3_p);
strides4_w.push_back(str4_p);
}
int nd_ = nd;
while (contractable) {
bool changed = false;
for (int i = 0; i + 1 < nd_; ++i) {
StridesTy str1 = strides1_w[i + 1];
StridesTy str2 = strides2_w[i + 1];
StridesTy str3 = strides3_w[i + 1];
StridesTy str4 = strides4_w[i + 1];
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1] - 1) * str1;
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1] - 1) * str2;
StridesTy jump3 = strides3_w[i] - (shape_w[i + 1] - 1) * str3;
StridesTy jump4 = strides4_w[i] - (shape_w[i + 1] - 1) * str4;

if (jump1 == str1 && jump2 == str2 && jump3 == str3 &&
jump4 == str4) {
changed = true;
shape_w[i] *= shape_w[i + 1];
for (int j = i; j < nd_; ++j) {
strides1_w[j] = strides1_w[j + 1];
}
for (int j = i; j < nd_; ++j) {
strides2_w[j] = strides2_w[j + 1];
}
for (int j = i; j < nd_; ++j) {
strides3_w[j] = strides3_w[j + 1];
}
for (int j = i; j < nd_; ++j) {
strides4_w[j] = strides4_w[j + 1];
}
for (int j = i + 1; j + 1 < nd_; ++j) {
shape_w[j] = shape_w[j + 1];
}
--nd_;
break;
}
}
if (!changed)
break;
}
for (int i = 0; i < nd_; ++i) {
shape[i] = shape_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides1[i] = strides1_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides2[i] = strides2_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides3[i] = strides3_w[i];
}
for (int i = 0; i < nd_; ++i) {
strides4[i] = strides4_w[i];
}

return nd_;
}

template <typename T, class Error, typename vecT = std::vector<T>>
std::tuple<vecT, vecT, T, vecT, T, vecT, T, vecT, T>
contract_iter4(vecT shape,
vecT strides1,
vecT strides2,
vecT strides3,
vecT strides4)
{
const size_t dim = shape.size();
if (dim != strides1.size() || dim != strides2.size() ||
dim != strides3.size() || dim != strides4.size())
{
throw Error("Shape and strides must be of equal size.");
}
vecT out_shape = shape;
vecT out_strides1 = strides1;
vecT out_strides2 = strides2;
vecT out_strides3 = strides3;
vecT out_strides4 = strides4;
T disp1(0);
T disp2(0);
T disp3(0);
T disp4(0);

int nd = simplify_iteration_four_strides(
dim, out_shape.data(), out_strides1.data(), out_strides2.data(),
out_strides3.data(), out_strides4.data(), disp1, disp2, disp3, disp4);
out_shape.resize(nd);
out_strides1.resize(nd);
out_strides2.resize(nd);
out_strides3.resize(nd);
out_strides4.resize(nd);
return std::make_tuple(out_shape, out_strides1, disp1, out_strides2, disp2,
out_strides3, disp3, out_strides4, disp4);
}

} // namespace strides
} // namespace tensor
} // namespace dpctl
Loading