@@ -348,7 +348,15 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
348
348
""" get the FPTR data structure """
349
349
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, param1_type, param1_type)
350
350
351
- result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
351
+ x1_obj = x1.get_array()
352
+
353
+ # get FPTR function and return type
354
+ cdef (DPNPFuncType, void * ) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
355
+ x1_obj.sycl_device.has_aspect_fp64)
356
+ cdef DPNPFuncType return_type = ret_type_and_func[0 ]
357
+ cdef fptr_1in_1out_strides_t func = < fptr_1in_1out_strides_t > ret_type_and_func[1 ]
358
+
359
+ result_type = dpnp_DPNPFuncType_to_dtype( < size_t > return_type)
352
360
353
361
cdef shape_type_c x1_shape = x1.shape
354
362
cdef shape_type_c x1_strides = utils.strides_to_vector(x1.strides, x1_shape)
@@ -358,9 +366,8 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
358
366
359
367
if out is None :
360
368
""" Create result array with type given by FPTR data """
361
- x1_obj = x1.get_array()
362
369
result = utils.create_output_descriptor(result_shape,
363
- kernel_data. return_type,
370
+ return_type,
364
371
None ,
365
372
device = x1_obj.sycl_device,
366
373
usm_type = x1_obj.usm_type,
@@ -383,7 +390,6 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
383
390
cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result_shape)
384
391
385
392
""" Call FPTR function """
386
- cdef fptr_1in_1out_strides_t func = < fptr_1in_1out_strides_t > kernel_data.ptr
387
393
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
388
394
result.get_data(),
389
395
result.size,
@@ -419,20 +425,26 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name,
419
425
# get the FPTR data structure
420
426
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
421
427
422
- result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
428
+ result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
429
+
430
+ # get FPTR function and return type
431
+ cdef (DPNPFuncType, void * ) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
432
+ result_sycl_device.has_aspect_fp64)
433
+ cdef DPNPFuncType return_type = ret_type_and_func[0 ]
434
+ cdef fptr_2in_1out_t func = < fptr_2in_1out_t > ret_type_and_func[1 ]
435
+
436
+ result_type = dpnp_DPNPFuncType_to_dtype( < size_t > return_type)
423
437
424
438
# Create result array
425
439
cdef shape_type_c x1_shape = x1_obj.shape
426
440
cdef shape_type_c x2_shape = x2_obj.shape
427
441
cdef shape_type_c result_shape = utils.get_common_shape(x1_shape, x2_shape)
428
442
cdef utils.dpnp_descriptor result
429
443
430
- result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
431
-
432
444
if out is None :
433
445
""" Create result array with type given by FPTR data """
434
446
result = utils.create_output_descriptor(result_shape,
435
- kernel_data. return_type,
447
+ return_type,
436
448
None ,
437
449
device = result_sycl_device,
438
450
usm_type = result_usm_type,
@@ -451,7 +463,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name,
451
463
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
452
464
453
465
""" Call FPTR function """
454
- cdef fptr_2in_1out_t func = < fptr_2in_1out_t > kernel_data.ptr
455
466
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
456
467
result.get_data(),
457
468
x1_obj.get_data(),
@@ -485,6 +496,14 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
485
496
# get the FPTR data structure
486
497
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
487
498
499
+ result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
500
+
501
+ # get FPTR function and return type
502
+ cdef (DPNPFuncType, void * ) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
503
+ result_sycl_device.has_aspect_fp64)
504
+ cdef DPNPFuncType return_type = ret_type_and_func[0 ]
505
+ cdef fptr_2in_1out_strides_t func = < fptr_2in_1out_strides_t > ret_type_and_func[1 ]
506
+
488
507
# Create result array
489
508
cdef shape_type_c x1_shape = x1_obj.shape
490
509
@@ -495,12 +514,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
495
514
cdef shape_type_c result_shape = utils.get_common_shape(x1_shape, x2_shape)
496
515
cdef utils.dpnp_descriptor result
497
516
498
- result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
499
-
500
- # get FPTR function and return type
501
- cdef fptr_2in_1out_strides_t func = < fptr_2in_1out_strides_t > kernel_data.ptr
502
- cdef DPNPFuncType return_type = kernel_data.return_type
503
-
504
517
# check 'out' parameter data
505
518
if out is not None :
506
519
if out.shape != result_shape:
0 commit comments