|
27 | 27 |
|
28 | 28 | #include <dpnp_iface.hpp> |
29 | 29 | #include "dpnp_fptr.hpp" |
30 | | -#include "dpnp_iterator.hpp" |
31 | 30 | #include "dpnpc_memory_adapter.hpp" |
32 | 31 | #include "queue_sycl.hpp" |
33 | 32 |
|
@@ -140,258 +139,6 @@ DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef, |
140 | 139 | size_t, |
141 | 140 | const DPCTLEventVectorRef) = dpnp_argmin_c<_DataType, _idx_DataType>; |
142 | 141 |
|
143 | | - |
144 | | -template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> |
145 | | -class dpnp_where_c_broadcast_kernel; |
146 | | - |
147 | | -template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> |
148 | | -class dpnp_where_c_strides_kernel; |
149 | | - |
150 | | -template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> |
151 | | -class dpnp_where_c_kernel; |
152 | | - |
153 | | -template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> |
154 | | -DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref, |
155 | | - void* result_out, |
156 | | - const size_t result_size, |
157 | | - const size_t result_ndim, |
158 | | - const shape_elem_type* result_shape, |
159 | | - const shape_elem_type* result_strides, |
160 | | - const void* condition_in, |
161 | | - const size_t condition_size, |
162 | | - const size_t condition_ndim, |
163 | | - const shape_elem_type* condition_shape, |
164 | | - const shape_elem_type* condition_strides, |
165 | | - const void* input1_in, |
166 | | - const size_t input1_size, |
167 | | - const size_t input1_ndim, |
168 | | - const shape_elem_type* input1_shape, |
169 | | - const shape_elem_type* input1_strides, |
170 | | - const void* input2_in, |
171 | | - const size_t input2_size, |
172 | | - const size_t input2_ndim, |
173 | | - const shape_elem_type* input2_shape, |
174 | | - const shape_elem_type* input2_strides, |
175 | | - const DPCTLEventVectorRef dep_event_vec_ref) |
176 | | -{ |
177 | | - /* avoid warning unused variable*/ |
178 | | - (void)dep_event_vec_ref; |
179 | | - |
180 | | - DPCTLSyclEventRef event_ref = nullptr; |
181 | | - |
182 | | - if (!condition_size || !input1_size || !input2_size) |
183 | | - { |
184 | | - return event_ref; |
185 | | - } |
186 | | - |
187 | | - sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); |
188 | | - |
189 | | - bool* condition_data = static_cast<bool*>(const_cast<void*>(condition_in)); |
190 | | - _DataType_input1* input1_data = static_cast<_DataType_input1*>(const_cast<void*>(input1_in)); |
191 | | - _DataType_input2* input2_data = static_cast<_DataType_input2*>(const_cast<void*>(input2_in)); |
192 | | - _DataType_output* result = static_cast<_DataType_output*>(result_out); |
193 | | - |
194 | | - bool use_broadcasting = !array_equal(input1_shape, input1_ndim, input2_shape, input2_ndim); |
195 | | - use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input1_shape, input1_ndim); |
196 | | - use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input2_shape, input2_ndim); |
197 | | - |
198 | | - shape_elem_type* condition_shape_offsets = new shape_elem_type[condition_ndim]; |
199 | | - |
200 | | - get_shape_offsets_inkernel(condition_shape, condition_ndim, condition_shape_offsets); |
201 | | - bool use_strides = !array_equal(condition_strides, condition_ndim, condition_shape_offsets, condition_ndim); |
202 | | - delete[] condition_shape_offsets; |
203 | | - |
204 | | - shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim]; |
205 | | - |
206 | | - get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets); |
207 | | - use_strides = use_strides || !array_equal(input1_strides, input1_ndim, input1_shape_offsets, input1_ndim); |
208 | | - delete[] input1_shape_offsets; |
209 | | - |
210 | | - shape_elem_type* input2_shape_offsets = new shape_elem_type[input2_ndim]; |
211 | | - |
212 | | - get_shape_offsets_inkernel(input2_shape, input2_ndim, input2_shape_offsets); |
213 | | - use_strides = use_strides || !array_equal(input2_strides, input2_ndim, input2_shape_offsets, input2_ndim); |
214 | | - delete[] input2_shape_offsets; |
215 | | - |
216 | | - sycl::event event; |
217 | | - sycl::range<1> gws(result_size); |
218 | | - |
219 | | - if (use_broadcasting) |
220 | | - { |
221 | | - DPNPC_id<bool>* condition_it; |
222 | | - const size_t condition_it_it_size_in_bytes = sizeof(DPNPC_id<bool>); |
223 | | - condition_it = reinterpret_cast<DPNPC_id<bool>*>(dpnp_memory_alloc_c(q_ref, condition_it_it_size_in_bytes)); |
224 | | - new (condition_it) DPNPC_id<bool>(q_ref, condition_data, condition_shape, condition_strides, condition_ndim); |
225 | | - |
226 | | - condition_it->broadcast_to_shape(result_shape, result_ndim); |
227 | | - |
228 | | - DPNPC_id<_DataType_input1>* input1_it; |
229 | | - const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>); |
230 | | - input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes)); |
231 | | - new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim); |
232 | | - |
233 | | - input1_it->broadcast_to_shape(result_shape, result_ndim); |
234 | | - |
235 | | - DPNPC_id<_DataType_input2>* input2_it; |
236 | | - const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>); |
237 | | - input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes)); |
238 | | - new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim); |
239 | | - |
240 | | - input2_it->broadcast_to_shape(result_shape, result_ndim); |
241 | | - |
242 | | - auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { |
243 | | - const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ |
244 | | - { |
245 | | - const bool condition = (*condition_it)[i]; |
246 | | - const _DataType_output input1_elem = (*input1_it)[i]; |
247 | | - const _DataType_output input2_elem = (*input2_it)[i]; |
248 | | - result[i] = (condition) ? input1_elem : input2_elem; |
249 | | - } |
250 | | - }; |
251 | | - auto kernel_func = [&](sycl::handler& cgh) { |
252 | | - cgh.parallel_for<class dpnp_where_c_broadcast_kernel<_DataType_output, _DataType_input1, _DataType_input2>>( |
253 | | - gws, kernel_parallel_for_func); |
254 | | - }; |
255 | | - |
256 | | - q.submit(kernel_func).wait(); |
257 | | - |
258 | | - condition_it->~DPNPC_id(); |
259 | | - input1_it->~DPNPC_id(); |
260 | | - input2_it->~DPNPC_id(); |
261 | | - |
262 | | - return event_ref; |
263 | | - } |
264 | | - else if (use_strides) |
265 | | - { |
266 | | - if ((result_ndim != condition_ndim) || (result_ndim != input1_ndim) || (result_ndim != input2_ndim)) |
267 | | - { |
268 | | - throw std::runtime_error("Result ndim=" + std::to_string(result_ndim) + |
269 | | - " mismatches with either condition ndim=" + std::to_string(condition_ndim) + |
270 | | - " or input1 ndim=" + std::to_string(input1_ndim) + |
271 | | - " or input2 ndim=" + std::to_string(input2_ndim)); |
272 | | - } |
273 | | - |
274 | | - /* memory transfer optimization, use USM-host for temporary speeds up tranfer to device */ |
275 | | - using usm_host_allocatorT = sycl::usm_allocator<shape_elem_type, sycl::usm::alloc::host>; |
276 | | - |
277 | | - size_t strides_size = 4 * result_ndim; |
278 | | - shape_elem_type* dev_strides_data = sycl::malloc_device<shape_elem_type>(strides_size, q); |
279 | | - |
280 | | - /* create host temporary for packed strides managed by shared pointer */ |
281 | | - auto strides_host_packed = |
282 | | - std::vector<shape_elem_type, usm_host_allocatorT>(strides_size, usm_host_allocatorT(q)); |
283 | | - |
284 | | - /* packed vector is concatenation of result_strides, condition_strides, input1_strides and input2_strides */ |
285 | | - std::copy(result_strides, result_strides + result_ndim, strides_host_packed.begin()); |
286 | | - std::copy(condition_strides, condition_strides + result_ndim, strides_host_packed.begin() + result_ndim); |
287 | | - std::copy(input1_strides, input1_strides + result_ndim, strides_host_packed.begin() + 2 * result_ndim); |
288 | | - std::copy(input2_strides, input2_strides + result_ndim, strides_host_packed.begin() + 3 * result_ndim); |
289 | | - |
290 | | - auto copy_strides_ev = |
291 | | - q.copy<shape_elem_type>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size()); |
292 | | - |
293 | | - auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { |
294 | | - const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ |
295 | | - { |
296 | | - const shape_elem_type* result_strides_data = &dev_strides_data[0]; |
297 | | - const shape_elem_type* condition_strides_data = &dev_strides_data[result_ndim]; |
298 | | - const shape_elem_type* input1_strides_data = &dev_strides_data[2 * result_ndim]; |
299 | | - const shape_elem_type* input2_strides_data = &dev_strides_data[3 * result_ndim]; |
300 | | - |
301 | | - size_t condition_id = 0; |
302 | | - size_t input1_id = 0; |
303 | | - size_t input2_id = 0; |
304 | | - |
305 | | - for (size_t i = 0; i < result_ndim; ++i) |
306 | | - { |
307 | | - const size_t output_xyz_id = |
308 | | - get_xyz_id_by_id_inkernel(output_id, result_strides_data, result_ndim, i); |
309 | | - condition_id += output_xyz_id * condition_strides_data[i]; |
310 | | - input1_id += output_xyz_id * input1_strides_data[i]; |
311 | | - input2_id += output_xyz_id * input2_strides_data[i]; |
312 | | - } |
313 | | - |
314 | | - const bool condition = condition_data[condition_id]; |
315 | | - const _DataType_output input1_elem = input1_data[input1_id]; |
316 | | - const _DataType_output input2_elem = input2_data[input2_id]; |
317 | | - result[output_id] = (condition) ? input1_elem : input2_elem; |
318 | | - } |
319 | | - }; |
320 | | - auto kernel_func = [&](sycl::handler& cgh) { |
321 | | - cgh.depends_on(copy_strides_ev); |
322 | | - cgh.parallel_for<class dpnp_where_c_strides_kernel<_DataType_output, _DataType_input1, _DataType_input2>>( |
323 | | - gws, kernel_parallel_for_func); |
324 | | - }; |
325 | | - |
326 | | - q.submit(kernel_func).wait(); |
327 | | - |
328 | | - sycl::free(dev_strides_data, q); |
329 | | - return event_ref; |
330 | | - } |
331 | | - else |
332 | | - { |
333 | | - auto kernel_parallel_for_func = [=](sycl::id<1> global_id) { |
334 | | - const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */ |
335 | | - |
336 | | - const bool condition = condition_data[i]; |
337 | | - const _DataType_output input1_elem = input1_data[i]; |
338 | | - const _DataType_output input2_elem = input2_data[i]; |
339 | | - result[i] = (condition) ? input1_elem : input2_elem; |
340 | | - }; |
341 | | - auto kernel_func = [&](sycl::handler& cgh) { |
342 | | - cgh.parallel_for<class dpnp_where_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>( |
343 | | - gws, kernel_parallel_for_func); |
344 | | - }; |
345 | | - event = q.submit(kernel_func); |
346 | | - } |
347 | | - |
348 | | - event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); |
349 | | - return DPCTLEvent_Copy(event_ref); |
350 | | - |
351 | | - return event_ref; |
352 | | -} |
353 | | - |
354 | | -template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2> |
355 | | -DPCTLSyclEventRef (*dpnp_where_ext_c)(DPCTLSyclQueueRef, |
356 | | - void*, |
357 | | - const size_t, |
358 | | - const size_t, |
359 | | - const shape_elem_type*, |
360 | | - const shape_elem_type*, |
361 | | - const void*, |
362 | | - const size_t, |
363 | | - const size_t, |
364 | | - const shape_elem_type*, |
365 | | - const shape_elem_type*, |
366 | | - const void*, |
367 | | - const size_t, |
368 | | - const size_t, |
369 | | - const shape_elem_type*, |
370 | | - const shape_elem_type*, |
371 | | - const void*, |
372 | | - const size_t, |
373 | | - const size_t, |
374 | | - const shape_elem_type*, |
375 | | - const shape_elem_type*, |
376 | | - const DPCTLEventVectorRef) = dpnp_where_c<_DataType_output, _DataType_input1, _DataType_input2>; |
377 | | - |
378 | | -template <DPNPFuncType FT1, DPNPFuncType... FTs> |
379 | | -static void func_map_searching_2arg_3type_core(func_map_t& fmap) |
380 | | -{ |
381 | | - ((fmap[DPNPFuncName::DPNP_FN_WHERE_EXT][FT1][FTs] = |
382 | | - {populate_func_types<FT1, FTs>(), |
383 | | - (void*)dpnp_where_ext_c<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>, |
384 | | - func_type_map_t::find_type<FT1>, |
385 | | - func_type_map_t::find_type<FTs>>}), |
386 | | - ...); |
387 | | -} |
388 | | - |
389 | | -template <DPNPFuncType... FTs> |
390 | | -static void func_map_searching_2arg_3type_helper(func_map_t& fmap) |
391 | | -{ |
392 | | - ((func_map_searching_2arg_3type_core<FTs, FTs...>(fmap)), ...); |
393 | | -} |
394 | | - |
395 | 142 | void func_map_init_searching(func_map_t& fmap) |
396 | 143 | { |
397 | 144 | fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_argmax_default_c<int32_t, int32_t>}; |
@@ -430,7 +177,5 @@ void func_map_init_searching(func_map_t& fmap) |
430 | 177 | fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_argmin_ext_c<double, int32_t>}; |
431 | 178 | fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_argmin_ext_c<double, int64_t>}; |
432 | 179 |
|
433 | | - func_map_searching_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap); |
434 | | - |
435 | 180 | return; |
436 | 181 | } |
0 commit comments