Skip to content

Commit 2f3af4a

Browse files
committed
Implemented vec_cast in dpctl::tensor::type_utils
Where kernel for contiguous data now uses vec_cast
1 parent 6e8d6ef commit 2f3af4a

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ class WhereContigFunctor
8080
T *dst_data = reinterpret_cast<T *>(dst_cp);
8181
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
8282

83-
using dpctl::tensor::type_utils::convert_impl;
84-
8583
using dpctl::tensor::type_utils::is_complex;
8684
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
8785
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -92,6 +90,7 @@ class WhereContigFunctor
9290
offset < std::min(nelems, base + sgSize * (n_vecs * vec_sz));
9391
offset += sgSize)
9492
{
93+
using dpctl::tensor::type_utils::convert_impl;
9594
bool check = convert_impl<bool, condT>(cond_data[offset]);
9695
dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
9796
}
@@ -115,7 +114,6 @@ class WhereContigFunctor
115114
using cond_ptrT =
116115
sycl::multi_ptr<const condT,
117116
sycl::access::address_space::global_space>;
118-
119117
sycl::vec<T, vec_sz> dst_vec;
120118
sycl::vec<T, vec_sz> x1_vec;
121119
sycl::vec<T, vec_sz> x2_vec;
@@ -127,18 +125,32 @@ class WhereContigFunctor
127125
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
128126
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
129127
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
130-
128+
if constexpr (std::is_same_v<bool, condT>) {
129+
#pragma unroll
130+
for (std::uint8_t k = 0; k < vec_sz; ++k) {
131+
bool check = cond_vec[k];
132+
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
133+
}
134+
}
135+
else {
136+
using dpctl::tensor::type_utils::vec_cast;
137+
sycl::vec<bool, vec_sz> tmp =
138+
vec_cast<bool,
139+
typename decltype(cond_vec)::element_type,
140+
vec_sz>(cond_vec);
131141
#pragma unroll
132-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
133-
bool check = convert_impl<bool, condT>(cond_vec[k]);
134-
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
142+
for (std::uint8_t k = 0; k < vec_sz; ++k) {
143+
bool check = tmp[k];
144+
dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
145+
}
135146
}
136147
sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
137148
}
138149
}
139150
else {
140151
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
141152
k += sgSize) {
153+
using dpctl::tensor::type_utils::convert_impl;
142154
bool check = convert_impl<bool, condT>(cond_data[k]);
143155
dst_data[k] = check ? x1_data[k] : x2_data[k];
144156
}

dpctl/tensor/libtensor/include/utils/type_utils.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@ template <class T> struct is_complex<std::complex<T>> : std::true_type
4141
{
4242
};
4343

44+
/*! @brief Helper for vec_cast: builds a `DstVec` from selected elements of
 *         `src`.
 *
 * The result is list-initialized from `src[Idx]` for every index in the
 * supplied `std::index_sequence`, in the order the indices appear in the
 * sequence.
 *
 * @tparam DstVec Type to construct; must be list-initializable from
 *                `sizeof...(Idx)` values of the type yielded by `src[i]`.
 * @tparam SrcVec Source type; must support `operator[]` on a const object.
 * @tparam Idx    Indices of the elements of `src` to forward.
 * @param  src    Object whose elements are read.
 * @return A `DstVec` initialized from the selected elements.
 */
template <typename DstVec, typename SrcVec, std::size_t... Idx>
auto vec_cast_impl(const SrcVec &src, std::index_sequence<Idx...>)
{
    // Pack expansion yields DstVec{src[i0], src[i1], ...}.
    return DstVec{src[Idx]...};
}
49+
50+
template <typename dstT,
51+
typename srcT,
52+
std::size_t N,
53+
typename Indices = std::make_index_sequence<N>>
54+
auto vec_cast(const sycl::vec<srcT, N> &s)
55+
{
56+
return vec_cast_impl<sycl::vec<dstT, N>, sycl::vec<srcT, N>>(s, Indices{});
57+
}
58+
4459
template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
4560
{
4661
if constexpr (std::is_same<dstTy, srcTy>::value) {

0 commit comments

Comments
 (0)