@@ -57,29 +57,26 @@ class WhereContigFunctor
57
57
{
58
58
private:
59
59
size_t nelems = 0 ;
60
- const char *x1_cp = nullptr ;
61
- const char *x2_cp = nullptr ;
62
- char *dst_cp = nullptr ;
63
- const char *cond_cp = nullptr ;
60
+ const T *x1_p = nullptr ;
61
+ const T *x2_p = nullptr ;
62
+ T *dst_p = nullptr ;
63
+ const condT *cond_p = nullptr ;
64
64
65
65
public:
66
66
WhereContigFunctor (size_t nelems_,
67
- const char *cond_data_p,
68
- const char *x1_data_p,
69
- const char *x2_data_p,
70
- char *dst_data_p)
71
- : nelems(nelems_), x1_cp(x1_data_p), x2_cp(x2_data_p),
72
- dst_cp (dst_data_p), cond_cp(cond_data_p)
67
+ const char *cond_cp,
68
+ const char *x1_cp,
69
+ const char *x2_cp,
70
+ char *dst_cp)
71
+ : nelems(nelems_), x1_p(reinterpret_cast <const T *>(x1_cp)),
72
+ x2_p (reinterpret_cast <const T *>(x2_cp)),
73
+ dst_p(reinterpret_cast <T *>(dst_cp)),
74
+ cond_p(reinterpret_cast <const condT *>(cond_cp))
73
75
{
74
76
}
75
77
76
78
void operator ()(sycl::nd_item<1 > ndit) const
77
79
{
78
- const T *x1_data = reinterpret_cast <const T *>(x1_cp);
79
- const T *x2_data = reinterpret_cast <const T *>(x2_cp);
80
- T *dst_data = reinterpret_cast <T *>(dst_cp);
81
- const condT *cond_data = reinterpret_cast <const condT *>(cond_cp);
82
-
83
80
using dpctl::tensor::type_utils::is_complex;
84
81
if constexpr (is_complex<condT>::value || is_complex<T>::value) {
85
82
std::uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
@@ -91,8 +88,8 @@ class WhereContigFunctor
91
88
offset += sgSize)
92
89
{
93
90
using dpctl::tensor::type_utils::convert_impl;
94
- bool check = convert_impl<bool , condT>(cond_data [offset]);
95
- dst_data [offset] = check ? x1_data [offset] : x2_data [offset];
91
+ bool check = convert_impl<bool , condT>(cond_p [offset]);
92
+ dst_p [offset] = check ? x1_p [offset] : x2_p [offset];
96
93
}
97
94
}
98
95
else {
@@ -122,33 +119,20 @@ class WhereContigFunctor
122
119
#pragma unroll
123
120
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
124
121
auto idx = base + it * sgSize;
125
- x1_vec = sg.load <vec_sz>(x_ptrT (&x1_data[idx]));
126
- x2_vec = sg.load <vec_sz>(x_ptrT (&x2_data[idx]));
127
- cond_vec = sg.load <vec_sz>(cond_ptrT (&cond_data[idx]));
128
- if constexpr (std::is_same_v<bool , condT>) {
129
- #pragma unroll
130
- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
131
- dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
132
- }
133
- }
134
- else {
135
- using dpctl::tensor::type_utils::vec_cast;
136
- sycl::vec<bool , vec_sz> tmp =
137
- vec_cast<bool ,
138
- typename decltype (cond_vec)::element_type,
139
- vec_sz>(cond_vec);
122
+ x1_vec = sg.load <vec_sz>(x_ptrT (&x1_p[idx]));
123
+ x2_vec = sg.load <vec_sz>(x_ptrT (&x2_p[idx]));
124
+ cond_vec = sg.load <vec_sz>(cond_ptrT (&cond_p[idx]));
140
125
#pragma unroll
141
- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
142
- dst_vec[k] = tmp[k] ? x1_vec[k] : x2_vec[k];
143
- }
126
+ for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
127
+ dst_vec[k] = cond_vec[k] ? x1_vec[k] : x2_vec[k];
144
128
}
145
- sg.store <vec_sz>(dst_ptrT (&dst_data [idx]), dst_vec);
129
+ sg.store <vec_sz>(dst_ptrT (&dst_p [idx]), dst_vec);
146
130
}
147
131
}
148
132
else {
149
133
for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems;
150
134
k += sgSize) {
151
- dst_data [k] = cond_data [k] ? x1_data [k] : x2_data [k];
135
+ dst_p [k] = cond_p [k] ? x1_p [k] : x2_p [k];
152
136
}
153
137
}
154
138
}
@@ -167,10 +151,10 @@ typedef sycl::event (*where_contig_impl_fn_ptr_t)(
167
151
template <typename T, typename condT>
168
152
sycl::event where_contig_impl (sycl::queue q,
169
153
size_t nelems,
170
- const char *cond_p ,
171
- const char *x1_p ,
172
- const char *x2_p ,
173
- char *dst_p ,
154
+ const char *cond_cp ,
155
+ const char *x1_cp ,
156
+ const char *x2_cp ,
157
+ char *dst_cp ,
174
158
const std::vector<sycl::event> &depends)
175
159
{
176
160
sycl::event where_ev = q.submit ([&](sycl::handler &cgh) {
@@ -186,8 +170,8 @@ sycl::event where_contig_impl(sycl::queue q,
186
170
187
171
cgh.parallel_for <where_contig_kernel<T, condT, vec_sz, n_vecs>>(
188
172
sycl::nd_range<1 >(gws_range, lws_range),
189
- WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_p, x1_p ,
190
- x2_p, dst_p ));
173
+ WhereContigFunctor<T, condT, vec_sz, n_vecs>(nelems, cond_cp, x1_cp ,
174
+ x2_cp, dst_cp ));
191
175
});
192
176
193
177
return where_ev;
@@ -197,39 +181,36 @@ template <typename T, typename condT, typename IndexerT>
197
181
class WhereStridedFunctor
198
182
{
199
183
private:
200
- const char *x1_cp = nullptr ;
201
- const char *x2_cp = nullptr ;
202
- char *dst_cp = nullptr ;
203
- const char *cond_cp = nullptr ;
184
+ const T *x1_p = nullptr ;
185
+ const T *x2_p = nullptr ;
186
+ T *dst_p = nullptr ;
187
+ const condT *cond_p = nullptr ;
204
188
IndexerT indexer;
205
189
206
190
public:
207
- WhereStridedFunctor (const char *cond_data_p ,
208
- const char *x1_data_p ,
209
- const char *x2_data_p ,
210
- char *dst_data_p ,
191
+ WhereStridedFunctor (const char *cond_cp ,
192
+ const char *x1_cp ,
193
+ const char *x2_cp ,
194
+ char *dst_cp ,
211
195
IndexerT indexer_)
212
- : x1_cp(x1_data_p), x2_cp(x2_data_p), dst_cp(dst_data_p),
213
- cond_cp (cond_data_p), indexer(indexer_)
196
+ : x1_p(reinterpret_cast <const T *>(x1_cp)),
197
+ x2_p (reinterpret_cast <const T *>(x2_cp)),
198
+ dst_p(reinterpret_cast <T *>(dst_cp)),
199
+ cond_p(reinterpret_cast <const condT *>(cond_cp)), indexer(indexer_)
214
200
{
215
201
}
216
202
217
203
void operator ()(sycl::id<1 > id) const
218
204
{
219
- const T *x1_data = reinterpret_cast <const T *>(x1_cp);
220
- const T *x2_data = reinterpret_cast <const T *>(x2_cp);
221
- T *dst_data = reinterpret_cast <T *>(dst_cp);
222
- const condT *cond_data = reinterpret_cast <const condT *>(cond_cp);
223
-
224
205
size_t gid = id[0 ];
225
206
auto offsets = indexer (static_cast <py::ssize_t >(gid));
226
207
227
208
using dpctl::tensor::type_utils::convert_impl;
228
209
bool check =
229
- convert_impl<bool , condT>(cond_data [offsets.get_first_offset ()]);
210
+ convert_impl<bool , condT>(cond_p [offsets.get_first_offset ()]);
230
211
231
- dst_data [gid] = check ? x1_data [offsets.get_second_offset ()]
232
- : x2_data [offsets.get_third_offset ()];
212
+ dst_p [gid] = check ? x1_p [offsets.get_second_offset ()]
213
+ : x2_p [offsets.get_third_offset ()];
233
214
}
234
215
};
235
216
0 commit comments