Skip to content

Commit 2e0bb63

Browse files
committed
`vec_cast` removed from the where functor
- Casts from `char` pointers to typed pointers moved into the functor constructors
1 parent ee4627a commit 2e0bb63

File tree

1 file changed

+42
-61
lines changed
  • dpctl/tensor/libtensor/include/kernels

1 file changed

+42
-61
lines changed

dpctl/tensor/libtensor/include/kernels/where.hpp

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -57,29 +57,26 @@ class WhereContigFunctor
5757
{
5858
private:
5959
size_t nelems = 0;
60-
const char *x1_cp = nullptr;
61-
const char *x2_cp = nullptr;
62-
char *dst_cp = nullptr;
63-
const char *cond_cp = nullptr;
60+
const T *x1_p = nullptr;
61+
const T *x2_p = nullptr;
62+
T *dst_p = nullptr;
63+
const condT *cond_p = nullptr;
6464

6565
public:
6666
WhereContigFunctor(size_t nelems_,
67-
const char *cond_data_p,
68-
const char *x1_data_p,
69-
const char *x2_data_p,
70-
char *dst_data_p)
71-
: nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
72-
dst_cp(dst_data_p), cond_cp(cond_data_p)
67+
const char *cond_cp,
68+
const char *x1_cp,
69+
const char *x2_cp,
70+
char *dst_cp)
71+
: nelems(nelems_), x1_p(reinterpret_cast<const T *>(x1_cp)),
72+
x2_p(reinterpret_cast<const T *>(x2_cp)),
73+
dst_p(reinterpret_cast<T *>(dst_cp)),
74+
cond_p(reinterpret_cast<const condT *>(cond_cp))
7375
{
7476
}
7577

7678
void operator()(sycl::nd_item<1> ndit) const
7779
{
78-
const T *x1_data = reinterpret_cast<const T *>(x1_cp);
79-
const T *x2_data = reinterpret_cast<const T *>(x2_cp);
80-
T *dst_data = reinterpret_cast<T *>(dst_cp);
81-
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
82-
8380
using dpctl::tensor::type_utils::is_complex;
8481
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
8582
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
@@ -91,8 +88,8 @@ class WhereContigFunctor
9188
offset += sgSize)
9289
{
9390
using dpctl::tensor::type_utils::convert_impl;
94-
bool check = convert_impl<bool, condT>(cond_data[offset]);
95-
dst_data[offset] = check ? x1_data[offset] : x2_data[offset];
91+
bool check = convert_impl<bool, condT>(cond_p[offset]);
92+
dst_p[offset] = check ? x1_p[offset] : x2_p[offset];
9693
}
9794
}
9895
else {
@@ -122,33 +119,20 @@ class WhereContigFunctor
122119
#pragma unroll
123120
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
124121
auto idx = base + it * sgSize;
125-
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_data[idx]));
126-
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_data[idx]));
127-
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_data[idx]));
128-
if constexpr (std::is_same_v<bool, condT>) {
129-
#pragma unroll
130-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
131-
dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
132-
}
133-
}
134-
else {
135-
using dpctl::tensor::type_utils::vec_cast;
136-
sycl::vec<bool, vec_sz> tmp =
137-
vec_cast<bool,
138-
typename decltype(cond_vec)::element_type,
139-
vec_sz>(cond_vec);
122+
x1_vec = sg.load<vec_sz>(x_ptrT(&x1_p[idx]));
123+
x2_vec = sg.load<vec_sz>(x_ptrT(&x2_p[idx]));
124+
cond_vec = sg.load<vec_sz>(cond_ptrT(&cond_p[idx]));
140125
#pragma unroll
141-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
142-
dst_vec[k] = tmp[k] ? x1_vec[k] : x2_vec[k];
143-
}
126+
for (std::uint8_t k = 0; k < vec_sz; ++k) {
127+
dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
144128
}
145-
sg.store<vec_sz>(dst_ptrT(&dst_data[idx]), dst_vec);
129+
sg.store<vec_sz>(dst_ptrT(&dst_p[idx]), dst_vec);
146130
}
147131
}
148132
else {
149133
for (size_t k = base + sg.get_local_id()[0]; k < nelems;
150134
k += sgSize) {
151-
dst_data[k] = cond_data[k] ? x1_data[k] : x2_data[k];
135+
dst_p[k] = cond_p[k] ? x1_p[k] : x2_p[k];
152136
}
153137
}
154138
}
@@ -167,10 +151,10 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
167151
template <typename T, typename condT>
168152
sycl::event where_contig_impl(sycl::queue q,
169153
size_t nelems,
170-
const char *cond_p,
171-
const char *x1_p,
172-
const char *x2_p,
173-
char *dst_p,
154+
const char *cond_cp,
155+
const char *x1_cp,
156+
const char *x2_cp,
157+
char *dst_cp,
174158
const std::vector<sycl::event> &depends)
175159
{
176160
sycl::event where_ev = q.submit([&](sycl::handler &cgh) {
@@ -186,8 +170,8 @@ sycl::event where_contig_impl(sycl::queue q,
186170

187171
cgh.parallel_for<where_contig_kernel<T, condT, vec_sz, n_vecs>>(
188172
sycl::nd_range<1>(gws_range, lws_range),
189-
WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p,
190-
x2_p, dst_p));
173+
WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_cp, x1_cp,
174+
x2_cp, dst_cp));
191175
});
192176

193177
return where_ev;
@@ -197,39 +181,36 @@ template <typename T, typename condT, typename IndexerT>
197181
class WhereStridedFunctor
198182
{
199183
private:
200-
const char *x1_cp = nullptr;
201-
const char *x2_cp = nullptr;
202-
char *dst_cp = nullptr;
203-
const char *cond_cp = nullptr;
184+
const T *x1_p = nullptr;
185+
const T *x2_p = nullptr;
186+
T *dst_p = nullptr;
187+
const condT *cond_p = nullptr;
204188
IndexerT indexer;
205189

206190
public:
207-
WhereStridedFunctor(const char *cond_data_p,
208-
const char *x1_data_p,
209-
const char *x2_data_p,
210-
char *dst_data_p,
191+
WhereStridedFunctor(const char *cond_cp,
192+
const char *x1_cp,
193+
const char *x2_cp,
194+
char *dst_cp,
211195
IndexerT indexer_)
212-
: x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
213-
cond_cp(cond_data_p), indexer(indexer_)
196+
: x1_p(reinterpret_cast<const T *>(x1_cp)),
197+
x2_p(reinterpret_cast<const T *>(x2_cp)),
198+
dst_p(reinterpret_cast<T *>(dst_cp)),
199+
cond_p(reinterpret_cast<const condT *>(cond_cp)), indexer(indexer_)
214200
{
215201
}
216202

217203
void operator()(sycl::id<1> id) const
218204
{
219-
const T *x1_data = reinterpret_cast<const T *>(x1_cp);
220-
const T *x2_data = reinterpret_cast<const T *>(x2_cp);
221-
T *dst_data = reinterpret_cast<T *>(dst_cp);
222-
const condT *cond_data = reinterpret_cast<const condT *>(cond_cp);
223-
224205
size_t gid = id[0];
225206
auto offsets = indexer(static_cast<py::ssize_t>(gid));
226207

227208
using dpctl::tensor::type_utils::convert_impl;
228209
bool check =
229-
convert_impl<bool, condT>(cond_data[offsets.get_first_offset()]);
210+
convert_impl<bool, condT>(cond_p[offsets.get_first_offset()]);
230211

231-
dst_data[gid] = check ? x1_data[offsets.get_second_offset()]
232-
: x2_data[offsets.get_third_offset()];
212+
dst_p[gid] = check ? x1_p[offsets.get_second_offset()]
213+
: x2_p[offsets.get_third_offset()];
233214
}
234215
};
235216

0 commit comments

Comments
 (0)