Skip to content

Commit 713302a

Browse files
Use get_ret_type_and_func in call_fptr functions (#1500)
* Updated get_ret_type_and_func takes cpp_bool has_aspect_fp64 argument * Use the updated get_ret_type_and_func function in call_fptr_1in_1out_strides, call_fptr_2in_1out and call_fptr_2in_1out_strides functions for further test updates to run on Iris Xe.
1 parent aa4687e commit 713302a

File tree

6 files changed

+57
-30
lines changed

6 files changed

+57
-30
lines changed

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,15 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
348348
""" get the FPTR data structure """
349349
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, param1_type, param1_type)
350350

351-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
351+
x1_obj = x1.get_array()
352+
353+
# get FPTR function and return type
354+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
355+
x1_obj.sycl_device.has_aspect_fp64)
356+
cdef DPNPFuncType return_type = ret_type_and_func[0]
357+
cdef fptr_1in_1out_strides_t func = < fptr_1in_1out_strides_t > ret_type_and_func[1]
358+
359+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > return_type)
352360

353361
cdef shape_type_c x1_shape = x1.shape
354362
cdef shape_type_c x1_strides = utils.strides_to_vector(x1.strides, x1_shape)
@@ -358,9 +366,8 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
358366

359367
if out is None:
360368
""" Create result array with type given by FPTR data """
361-
x1_obj = x1.get_array()
362369
result = utils.create_output_descriptor(result_shape,
363-
kernel_data.return_type,
370+
return_type,
364371
None,
365372
device=x1_obj.sycl_device,
366373
usm_type=x1_obj.usm_type,
@@ -383,7 +390,6 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name,
383390
cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result_shape)
384391

385392
""" Call FPTR function """
386-
cdef fptr_1in_1out_strides_t func = <fptr_1in_1out_strides_t > kernel_data.ptr
387393
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
388394
result.get_data(),
389395
result.size,
@@ -419,20 +425,26 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name,
419425
# get the FPTR data structure
420426
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
421427

422-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
428+
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
429+
430+
# get FPTR function and return type
431+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
432+
result_sycl_device.has_aspect_fp64)
433+
cdef DPNPFuncType return_type = ret_type_and_func[0]
434+
cdef fptr_2in_1out_t func = < fptr_2in_1out_t > ret_type_and_func[1]
435+
436+
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > return_type)
423437

424438
# Create result array
425439
cdef shape_type_c x1_shape = x1_obj.shape
426440
cdef shape_type_c x2_shape = x2_obj.shape
427441
cdef shape_type_c result_shape = utils.get_common_shape(x1_shape, x2_shape)
428442
cdef utils.dpnp_descriptor result
429443

430-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
431-
432444
if out is None:
433445
""" Create result array with type given by FPTR data """
434446
result = utils.create_output_descriptor(result_shape,
435-
kernel_data.return_type,
447+
return_type,
436448
None,
437449
device=result_sycl_device,
438450
usm_type=result_usm_type,
@@ -451,7 +463,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name,
451463
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
452464

453465
""" Call FPTR function """
454-
cdef fptr_2in_1out_t func = <fptr_2in_1out_t > kernel_data.ptr
455466
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
456467
result.get_data(),
457468
x1_obj.get_data(),
@@ -485,6 +496,14 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
485496
# get the FPTR data structure
486497
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type)
487498

499+
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
500+
501+
# get FPTR function and return type
502+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
503+
result_sycl_device.has_aspect_fp64)
504+
cdef DPNPFuncType return_type = ret_type_and_func[0]
505+
cdef fptr_2in_1out_strides_t func = < fptr_2in_1out_strides_t > ret_type_and_func[1]
506+
488507
# Create result array
489508
cdef shape_type_c x1_shape = x1_obj.shape
490509

@@ -495,12 +514,6 @@ cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name,
495514
cdef shape_type_c result_shape = utils.get_common_shape(x1_shape, x2_shape)
496515
cdef utils.dpnp_descriptor result
497516

498-
result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj)
499-
500-
# get FPTR function and return type
501-
cdef fptr_2in_1out_strides_t func = < fptr_2in_1out_strides_t > kernel_data.ptr
502-
cdef DPNPFuncType return_type = kernel_data.return_type
503-
504517
# check 'out' parameter data
505518
if out is not None:
506519
if out.shape != result_shape:

dpnp/dpnp_algo/dpnp_algo_statistics.pxi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,8 @@ cpdef utils.dpnp_descriptor dpnp_median(utils.dpnp_descriptor array1):
265265

266266
array1_obj = array1.get_array()
267267

268-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(array1_obj, kernel_data)
268+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
269+
array1_obj.sycl_device.has_aspect_fp64)
269270
cdef DPNPFuncType return_type = ret_type_and_func[0]
270271
cdef custom_statistic_1in_1out_func_ptr_t func = < custom_statistic_1in_1out_func_ptr_t > ret_type_and_func[1]
271272

dpnp/dpnp_utils/dpnp_algo_utils.pxd

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,9 @@ cdef tuple get_common_usm_allocation(dpnp_descriptor x1, dpnp_descriptor x2)
163163
Get common USM allocation in the form of (sycl_device, usm_type, sycl_queue)
164164
"""
165165

166-
cdef (DPNPFuncType, void *) get_ret_type_and_func(x1_obj, DPNPFuncData kernel_data)
166+
cdef (DPNPFuncType, void *) get_ret_type_and_func(DPNPFuncData kernel_data,
167+
cpp_bool has_aspect_fp64)
167168
"""
168169
Get the corresponding return type and function pointer based on the
169-
capability of the allocated input array device for the integer types.
170+
capability of the allocated result array device.
170171
"""

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -663,13 +663,20 @@ cdef tuple get_common_usm_allocation(dpnp_descriptor x1, dpnp_descriptor x2):
663663
return (common_sycl_queue.sycl_device, common_usm_type, common_sycl_queue)
664664

665665

666-
cdef (DPNPFuncType, void *) get_ret_type_and_func(x1_obj, DPNPFuncData kernel_data):
667-
if dpnp.issubdtype(x1_obj.dtype, dpnp.integer) and not x1_obj.sycl_device.has_aspect_fp64:
666+
cdef (DPNPFuncType, void *) get_ret_type_and_func(DPNPFuncData kernel_data,
667+
cpp_bool has_aspect_fp64):
668+
"""
669+
This function is responsible for determining the appropriate return type
670+
and function pointer based on the capability of the allocated result array device.
671+
"""
672+
return_type = kernel_data.return_type
673+
func = kernel_data.ptr
674+
675+
if kernel_data.ptr_no_fp64 != NULL and not has_aspect_fp64:
676+
668677
return_type = kernel_data.return_type_no_fp64
669678
func = kernel_data.ptr_no_fp64
670-
else:
671-
return_type = kernel_data.return_type
672-
func = kernel_data.ptr
679+
673680
return return_type, func
674681

675682

dpnp/dpnp_utils/dpnp_utils_statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def _get_2dmin_array(x, dtype):
7979
dtypes = [m.dtype, dpnp.default_float_type(sycl_queue=queue)]
8080
if y is not None:
8181
dtypes.append(y.dtype)
82-
dtype = dpt.result_type(*dtypes)
83-
# TODO: remove when dpctl.result_type() is fixed
82+
dtype = dpnp.result_type(*dtypes)
83+
# TODO: remove when dpctl.result_type() is returned dtype based on fp64
8484
fp64 = queue.sycl_device.has_aspect_fp64
8585
if not fp64:
8686
if dtype == dpnp.float64:

dpnp/linalg/dpnp_algo_linalg.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ cpdef tuple dpnp_eig(utils.dpnp_descriptor x1):
196196

197197
x1_obj = x1.get_array()
198198

199-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(x1_obj, kernel_data)
199+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
200+
x1_obj.sycl_device.has_aspect_fp64)
200201
cdef DPNPFuncType return_type = ret_type_and_func[0]
201202
cdef custom_linalg_2in_1out_func_ptr_t func = < custom_linalg_2in_1out_func_ptr_t > ret_type_and_func[1]
202203

@@ -242,7 +243,8 @@ cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input):
242243

243244
input_obj = input.get_array()
244245

245-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(input_obj, kernel_data)
246+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
247+
input_obj.sycl_device.has_aspect_fp64)
246248
cdef DPNPFuncType return_type = ret_type_and_func[0]
247249
cdef custom_linalg_1in_1out_with_size_func_ptr_t_ func = < custom_linalg_1in_1out_with_size_func_ptr_t_ > ret_type_and_func[1]
248250

@@ -281,7 +283,8 @@ cpdef utils.dpnp_descriptor dpnp_inv(utils.dpnp_descriptor input):
281283

282284
input_obj = input.get_array()
283285

284-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(input_obj, kernel_data)
286+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
287+
input_obj.sycl_device.has_aspect_fp64)
285288
cdef DPNPFuncType return_type = ret_type_and_func[0]
286289
cdef custom_linalg_1in_1out_func_ptr_t func = < custom_linalg_1in_1out_func_ptr_t > ret_type_and_func[1]
287290

@@ -462,7 +465,8 @@ cpdef tuple dpnp_qr(utils.dpnp_descriptor x1, str mode):
462465

463466
x1_obj = x1.get_array()
464467

465-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(x1_obj, kernel_data)
468+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
469+
x1_obj.sycl_device.has_aspect_fp64)
466470
cdef DPNPFuncType return_type = ret_type_and_func[0]
467471
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > ret_type_and_func[1]
468472

@@ -515,7 +519,8 @@ cpdef tuple dpnp_svd(utils.dpnp_descriptor x1, cpp_bool full_matrices, cpp_bool
515519

516520
x1_obj = x1.get_array()
517521

518-
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(x1_obj, kernel_data)
522+
cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data,
523+
x1_obj.sycl_device.has_aspect_fp64)
519524
cdef DPNPFuncType return_type = ret_type_and_func[0]
520525
cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > ret_type_and_func[1]
521526

0 commit comments

Comments
 (0)