Skip to content

Commit a9e824b

Browse files
committed
re-write dpnp.abs
1 parent 946ff08 commit a9e824b

12 files changed

+229
-106
lines changed

dpnp/backend/extensions/vm/abs.hpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event abs_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::abs(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct AbsContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::AbsOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return abs_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ namespace vm
4343
{
4444
namespace types
4545
{
46+
/**
47+
* @brief A factory to define pairs of supported types for which
48+
* MKL VM library provides support in oneapi::mkl::vm::abs<T> function.
49+
*
50+
* @tparam T Type of input vector `a` and of result vector `y`.
51+
*/
52+
template <typename T>
53+
struct AbsOutputType
54+
{
55+
using value_type = typename std::disjunction<
56+
dpctl_td_ns::TypeMapResultEntry<T, double>,
57+
dpctl_td_ns::TypeMapResultEntry<T, float>,
58+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
59+
};
60+
4661
/**
4762
* @brief A factory to define pairs of supported types for which
4863
* MKL VM library provides support in oneapi::mkl::vm::acos<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <pybind11/pybind11.h>
3131
#include <pybind11/stl.h>
3232

33+
#include "abs.hpp"
3334
#include "acos.hpp"
3435
#include "acosh.hpp"
3536
#include "add.hpp"
@@ -66,6 +67,7 @@ namespace vm_ext = dpnp::backend::ext::vm;
6667
using vm_ext::binary_impl_fn_ptr_t;
6768
using vm_ext::unary_impl_fn_ptr_t;
6869

70+
static unary_impl_fn_ptr_t abs_dispatch_vector[dpctl_td_ns::num_types];
6971
static unary_impl_fn_ptr_t acos_dispatch_vector[dpctl_td_ns::num_types];
7072
static unary_impl_fn_ptr_t acosh_dispatch_vector[dpctl_td_ns::num_types];
7173
static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types];
@@ -99,6 +101,34 @@ PYBIND11_MODULE(_vm_impl, m)
99101
using arrayT = dpctl::tensor::usm_ndarray;
100102
using event_vecT = std::vector<sycl::event>;
101103

104+
// UnaryUfunc: ==== Abs(x) ====
105+
{
106+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
107+
vm_ext::AbsContigFactory>(
108+
abs_dispatch_vector);
109+
110+
auto abs_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
111+
const event_vecT &depends = {}) {
112+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
113+
abs_dispatch_vector);
114+
};
115+
m.def("_abs", abs_pyapi,
116+
"Call `abs` function from OneMKL VM library to compute "
117+
"the absolute of vector elements",
118+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
119+
py::arg("depends") = py::list());
120+
121+
auto abs_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
122+
arrayT dst) {
123+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
124+
abs_dispatch_vector);
125+
};
126+
m.def("_mkl_abs_to_call", abs_need_to_call_pyapi,
127+
"Check input arguments to answer if `abs` function from "
128+
"OneMKL VM library can be used",
129+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
130+
}
131+
102132
// UnaryUfunc: ==== Acos(x) ====
103133
{
104134
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,9 @@
5858
*/
5959
enum class DPNPFuncName : size_t
6060
{
61-
DPNP_FN_NONE, /**< Very first element of the enumeration */
62-
DPNP_FN_ABSOLUTE, /**< Used in numpy.absolute() impl */
63-
DPNP_FN_ABSOLUTE_EXT, /**< Used in numpy.absolute() impl, requires extra
64-
parameters */
65-
DPNP_FN_ADD, /**< Used in numpy.add() impl */
61+
DPNP_FN_NONE, /**< Very first element of the enumeration */
62+
DPNP_FN_ABSOLUTE, /**< Used in numpy.absolute() impl */
63+
DPNP_FN_ADD, /**< Used in numpy.add() impl */
6664
DPNP_FN_ADD_EXT, /**< Used in numpy.add() impl, requires extra parameters */
6765
DPNP_FN_ALL, /**< Used in numpy.all() impl */
6866
DPNP_FN_ALLCLOSE, /**< Used in numpy.allclose() impl */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,6 @@ template <typename _DataType>
216216
void (*dpnp_elemwise_absolute_default_c)(const void *, void *, size_t) =
217217
dpnp_elemwise_absolute_c<_DataType>;
218218

219-
template <typename _DataType_input, typename _DataType_output = _DataType_input>
220-
DPCTLSyclEventRef (*dpnp_elemwise_absolute_ext_c)(DPCTLSyclQueueRef,
221-
const void *,
222-
void *,
223-
size_t,
224-
const DPCTLEventVectorRef) =
225-
dpnp_elemwise_absolute_c<_DataType_input, _DataType_output>;
226-
227219
template <typename _DataType_output,
228220
typename _DataType_input1,
229221
typename _DataType_input2>
@@ -1151,21 +1143,6 @@ void func_map_init_mathematical(func_map_t &fmap)
11511143
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_DBL][eft_DBL] = {
11521144
eft_DBL, (void *)dpnp_elemwise_absolute_default_c<double>};
11531145

1154-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_INT][eft_INT] = {
1155-
eft_INT, (void *)dpnp_elemwise_absolute_ext_c<int32_t>};
1156-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_LNG][eft_LNG] = {
1157-
eft_LNG, (void *)dpnp_elemwise_absolute_ext_c<int64_t>};
1158-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {
1159-
eft_FLT, (void *)dpnp_elemwise_absolute_ext_c<float>};
1160-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {
1161-
eft_DBL, (void *)dpnp_elemwise_absolute_ext_c<double>};
1162-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C64][eft_C64] = {
1163-
eft_FLT,
1164-
(void *)dpnp_elemwise_absolute_ext_c<std::complex<float>, float>};
1165-
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C128][eft_C128] = {
1166-
eft_DBL,
1167-
(void *)dpnp_elemwise_absolute_ext_c<std::complex<double>, double>};
1168-
11691146
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {
11701147
eft_INT, (void *)dpnp_around_default_c<int32_t>};
11711148
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ from dpnp.dpnp_utils.dpnp_algo_utils cimport dpnp_descriptor
3333

3434
cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this namespace for Enum import
3535
cdef enum DPNPFuncName "DPNPFuncName":
36-
DPNP_FN_ABSOLUTE
37-
DPNP_FN_ABSOLUTE_EXT
3836
DPNP_FN_ALLCLOSE
3937
DPNP_FN_ALLCLOSE_EXT
4038
DPNP_FN_ARANGE

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ and the rest of the library
3636
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file
3737

3838
__all__ += [
39-
"dpnp_absolute",
4039
"dpnp_copysign",
4140
"dpnp_cross",
4241
"dpnp_cumprod",
@@ -59,9 +58,6 @@ __all__ += [
5958
]
6059

6160

62-
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_custom_elemwise_absolute_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef,
63-
void * , void * , size_t,
64-
const c_dpctl.DPCTLEventVectorRef)
6561
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_1in_2out_t)(c_dpctl.DPCTLSyclQueueRef,
6662
void * , void * , void * , size_t,
6763
const c_dpctl.DPCTLEventVectorRef)
@@ -70,41 +66,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_trapz_2in_1out_with_2size_t)(c_d
7066
const c_dpctl.DPCTLEventVectorRef)
7167

7268

73-
cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
74-
cdef shape_type_c x1_shape = x1.shape
75-
cdef size_t x1_shape_size = x1.ndim
76-
77-
# convert string type names (array.dtype) to C enum DPNPFuncType
78-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
79-
80-
# get the FPTR data structure
81-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ABSOLUTE_EXT, param1_type, param1_type)
82-
83-
x1_obj = x1.get_array()
84-
85-
# ceate result array with type given by FPTR data
86-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(x1_shape,
87-
kernel_data.return_type,
88-
None,
89-
device=x1_obj.sycl_device,
90-
usm_type=x1_obj.usm_type,
91-
sycl_queue=x1_obj.sycl_queue)
92-
93-
result_sycl_queue = result.get_array().sycl_queue
94-
95-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
96-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
97-
98-
cdef fptr_custom_elemwise_absolute_1in_1out_t func = <fptr_custom_elemwise_absolute_1in_1out_t > kernel_data.ptr
99-
# call FPTR function
100-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, x1.get_data(), result.get_data(), x1.size, NULL)
101-
102-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
103-
c_dpctl.DPCTLEvent_Delete(event_ref)
104-
105-
return result
106-
107-
10869
cpdef utils.dpnp_descriptor dpnp_copysign(utils.dpnp_descriptor x1_obj,
10970
utils.dpnp_descriptor x2_obj,
11071
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
__all__ = [
4444
"check_nd_call_func",
45+
"dpnp_abs",
4546
"dpnp_acos",
4647
"dpnp_acosh",
4748
"dpnp_add",
@@ -168,6 +169,63 @@ def check_nd_call_func(
168169
)
169170

170171

172+
_abs_docstring = """
173+
abs(x, out=None, order='K')
174+
175+
Calculates the absolute value for each element `x_i` of input array `x`.
176+
177+
Args:
178+
x (dpnp.ndarray):
179+
Input array, expected to have numeric data type.
180+
out ({None, dpnp.ndarray}, optional):
181+
Output array to populate. Array must have the correct
182+
shape and the expected data type.
183+
order ("C","F","A","K", optional): memory layout of the new
184+
output array, if parameter `out` is `None`.
185+
Default: "K".
186+
Return:
187+
dpnp.ndarray:
188+
An array containing the element-wise absolute values.
189+
For complex input, the absolute value is its magnitude.
190+
If `x` has a real-valued data type, the returned array has the
191+
same data type as `x`. If `x` has a complex floating-point data type,
192+
the returned array has a real-valued floating-point data type whose
193+
precision matches the precision of `x`.
194+
"""
195+
196+
197+
def _call_abs(src, dst, sycl_queue, depends=None):
198+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
199+
200+
if depends is None:
201+
depends = []
202+
203+
if vmi._mkl_abs_to_call(sycl_queue, src, dst):
204+
# call pybind11 extension for abs() function from OneMKL VM
205+
return vmi._abs(sycl_queue, src, dst, depends)
206+
return ti._abs(src, dst, sycl_queue, depends)
207+
208+
209+
abs_func = UnaryElementwiseFunc(
210+
"abs", ti._abs_result_type, _call_abs, _abs_docstring
211+
)
212+
213+
214+
def dpnp_abs(x, out=None, order="K"):
215+
"""
216+
Invokes abs() function from pybind11 extension of OneMKL VM if possible.
217+
218+
Otherwise fully relies on dpctl.tensor implementation for abs() function.
219+
220+
"""
221+
# dpctl.tensor only works with usm_ndarray
222+
x1_usm = dpnp.get_usm_ndarray(x)
223+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
224+
225+
res_usm = abs_func(x1_usm, out=out_usm, order=order)
226+
return dpnp_array._create_from_usm_ndarray(res_usm)
227+
228+
171229
_acos_docstring = """
172230
acos(x, out=None, order='K')
173231

0 commit comments

Comments
 (0)