@@ -80,8 +80,6 @@ class WhereContigFunctor
80
80
T *dst_data = reinterpret_cast <T *>(dst_cp);
81
81
const condT *cond_data = reinterpret_cast <const condT *>(cond_cp);
82
82
83
- using dpctl::tensor::type_utils::convert_impl;
84
-
85
83
using dpctl::tensor::type_utils::is_complex;
86
84
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
87
85
std::uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
@@ -92,6 +90,7 @@ class WhereContigFunctor
92
90
offset < std::min (nelems, base + sgSize * (n_vecs * vec_sz));
93
91
offset += sgSize)
94
92
{
93
+ using dpctl::tensor::type_utils::convert_impl;
95
94
bool check = convert_impl<bool , condT>(cond_data[offset]);
96
95
dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
97
96
}
@@ -115,7 +114,6 @@ class WhereContigFunctor
115
114
using cond_ptrT =
116
115
sycl::multi_ptr<const condT,
117
116
sycl::access::address_space::global_space>;
118
-
119
117
sycl::vec<T, vec_sz> dst_vec;
120
118
sycl::vec<T, vec_sz> x1_vec;
121
119
sycl::vec<T, vec_sz> x2_vec;
@@ -127,18 +125,32 @@ class WhereContigFunctor
127
125
x1_vec = sg.load <vec_sz>(x_ptrT (&x1_data[idx]));
128
126
x2_vec = sg.load <vec_sz>(x_ptrT (&x2_data[idx]));
129
127
cond_vec = sg.load <vec_sz>(cond_ptrT (&cond_data[idx]));
130
-
128
+ if constexpr (std::is_same_v<bool , condT>) {
129
+ #pragma unroll
130
+ for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
131
+ bool check = cond_vec[k];
132
+ dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
133
+ }
134
+ }
135
+ else {
136
+ using dpctl::tensor::type_utils::vec_cast;
137
+ sycl::vec<bool , vec_sz> tmp =
138
+ vec_cast<bool ,
139
+ typename decltype (cond_vec)::element_type,
140
+ vec_sz>(cond_vec);
131
141
#pragma unroll
132
- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
133
- bool check = convert_impl<bool , condT>(cond_vec[k]);
134
- dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
142
+ for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
143
+ bool check = tmp[k];
144
+ dst_vec[k] = check ? x1_vec[k] : x2_vec[k];
145
+ }
135
146
}
136
147
sg.store <vec_sz>(dst_ptrT (&dst_data[idx]), dst_vec);
137
148
}
138
149
}
139
150
else {
140
151
for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems;
141
152
k += sgSize) {
153
+ using dpctl::tensor::type_utils::convert_impl;
142
154
bool check = convert_impl<bool , condT>(cond_data[k]);
143
155
dst_data[k] = check ? x1_data[k] : x2_data[k];
144
156
}
0 commit comments