
Commit 3cba1ce

use_dpctl_conj_for_dpnp
1 parent 42e02d9 commit 3cba1ce

File tree: 11 files changed (+227 / -45 lines)

dpnp/backend/extensions/vm/conj.hpp

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+//*****************************************************************************
+// Copyright (c) 2023, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+
+#pragma once
+
+#include <CL/sycl.hpp>
+
+#include "common.hpp"
+#include "types_matrix.hpp"
+
+namespace dpnp
+{
+namespace backend
+{
+namespace ext
+{
+namespace vm
+{
+template <typename T>
+sycl::event conj_contig_impl(sycl::queue exec_q,
+                             const std::int64_t n,
+                             const char *in_a,
+                             char *out_y,
+                             const std::vector<sycl::event> &depends)
+{
+    type_utils::validate_type_for_device<T>(exec_q);
+
+    const T *a = reinterpret_cast<const T *>(in_a);
+    T *y = reinterpret_cast<T *>(out_y);
+
+    return mkl_vm::conj(exec_q,
+                        n, // number of elements to be calculated
+                        a, // pointer `a` containing input vector of size n
+                        y, // pointer `y` to the output vector of size n
+                        depends);
+}
+
+template <typename fnT, typename T>
+struct ConjContigFactory
+{
+    fnT get()
+    {
+        if constexpr (std::is_same_v<
+                          typename types::ConjOutputType<T>::value_type, void>)
+        {
+            return nullptr;
+        }
+        else {
+            return conj_contig_impl<T>;
+        }
+    }
+};
+} // namespace vm
+} // namespace ext
+} // namespace backend
+} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 17 additions & 0 deletions

@@ -68,6 +68,23 @@ struct DivOutputType
         dpctl_td_ns::DefaultResultEntry<void>>::result_type;
 };
 
+/**
+ * @brief A factory to define pairs of supported types for which
+ * MKL VM library provides support in oneapi::mkl::vm::conj<T> function.
+ *
+ * @tparam T Type of input vector `a` and of result vector `y`.
+ */
+template <typename T>
+struct ConjOutputType
+{
+    using value_type = typename std::disjunction<
+        dpctl_td_ns::
+            TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
+        dpctl_td_ns::
+            TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
+        dpctl_td_ns::DefaultResultEntry<void>>::result_type;
+};
+
 /**
  * @brief A factory to define pairs of supported types for which
  * MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
dpnp/backend/extensions/vm/vm_py.cpp

Lines changed: 30 additions & 0 deletions

@@ -31,6 +31,7 @@
 #include <pybind11/stl.h>
 
 #include "common.hpp"
+#include "conj.hpp"
 #include "cos.hpp"
 #include "div.hpp"
 #include "ln.hpp"
@@ -48,6 +49,7 @@ using vm_ext::unary_impl_fn_ptr_t;
 static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
 
 static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
+static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
 static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
 static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
 static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
@@ -116,6 +118,34 @@ PYBIND11_MODULE(_vm_impl, m)
               py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
     }
 
+    // UnaryUfunc: ==== Conj(x) ====
+    {
+        vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
+                                           vm_ext::ConjContigFactory>(
+            conj_dispatch_vector);
+
+        auto conj_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
+                              const event_vecT &depends = {}) {
+            return vm_ext::unary_ufunc(exec_q, src, dst, depends,
+                                       conj_dispatch_vector);
+        };
+        m.def("_conj", conj_pyapi,
+              "Call `conj` function from OneMKL VM library to compute "
+              "conjugate of vector elements",
+              py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
+              py::arg("depends") = py::list());
+
+        auto conj_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
+                                           arrayT dst) {
+            return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
+                                                    conj_dispatch_vector);
+        };
+        m.def("_mkl_conj_to_call", conj_need_to_call_pyapi,
+              "Check input arguments to answer if `conj` function from "
+              "OneMKL VM library can be used",
+              py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
+    }
+
     // UnaryUfunc: ==== Ln(x) ====
     {
         vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 7 additions & 9 deletions

@@ -116,15 +116,13 @@ enum class DPNPFuncName : size_t
     DPNP_FN_CEIL_EXT, /**< Used in numpy.ceil() impl, requires extra parameters
                        */
     DPNP_FN_CHOLESKY, /**< Used in numpy.linalg.cholesky() impl */
-    DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires
-                             extra parameters */
-    DPNP_FN_CONJIGUATE, /**< Used in numpy.conjugate() impl */
-    DPNP_FN_CONJIGUATE_EXT, /**< Used in numpy.conjugate() impl, requires extra
-                               parameters */
-    DPNP_FN_CHOOSE, /**< Used in numpy.choose() impl */
-    DPNP_FN_CHOOSE_EXT, /**< Used in numpy.choose() impl, requires extra
-                           parameters */
-    DPNP_FN_COPY, /**< Used in numpy.copy() impl */
+    DPNP_FN_CHOLESKY_EXT, /**< Used in numpy.linalg.cholesky() impl, requires
+                             extra parameters */
+    DPNP_FN_CONJUGATE, /**< Used in numpy.conjugate() impl */
+    DPNP_FN_CHOOSE, /**< Used in numpy.choose() impl */
+    DPNP_FN_CHOOSE_EXT, /**< Used in numpy.choose() impl, requires extra
+                           parameters */
+    DPNP_FN_COPY, /**< Used in numpy.copy() impl */
     DPNP_FN_COPY_EXT, /**< Used in numpy.copy() impl, requires extra parameters
                        */
     DPNP_FN_COPYSIGN, /**< Used in numpy.copysign() impl */

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 5 additions & 16 deletions

@@ -1029,28 +1029,17 @@ constexpr auto dispatch_fmod_op(T elem1, T elem2)
 
 static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
 {
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_INT][eft_INT] = {
+    fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_INT][eft_INT] = {
         eft_INT, (void *)dpnp_copy_c_default<int32_t>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_LNG][eft_LNG] = {
+    fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_LNG][eft_LNG] = {
         eft_LNG, (void *)dpnp_copy_c_default<int64_t>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_FLT][eft_FLT] = {
+    fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_FLT][eft_FLT] = {
        eft_FLT, (void *)dpnp_copy_c_default<float>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_DBL][eft_DBL] = {
+    fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_DBL][eft_DBL] = {
         eft_DBL, (void *)dpnp_copy_c_default<double>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE][eft_C128][eft_C128] = {
+    fmap[DPNPFuncName::DPNP_FN_CONJUGATE][eft_C128][eft_C128] = {
         eft_C128, (void *)dpnp_conjugate_c_default<std::complex<double>>};
 
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_INT][eft_INT] = {
-        eft_INT, (void *)dpnp_copy_c_ext<int32_t>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_LNG][eft_LNG] = {
-        eft_LNG, (void *)dpnp_copy_c_ext<int64_t>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_FLT][eft_FLT] = {
-        eft_FLT, (void *)dpnp_copy_c_ext<float>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_DBL][eft_DBL] = {
-        eft_DBL, (void *)dpnp_copy_c_ext<double>};
-    fmap[DPNPFuncName::DPNP_FN_CONJIGUATE_EXT][eft_C128][eft_C128] = {
-        eft_C128, (void *)dpnp_conjugate_c_ext<std::complex<double>>};
-
     fmap[DPNPFuncName::DPNP_FN_COPY][eft_BLN][eft_BLN] = {
         eft_BLN, (void *)dpnp_copy_c_default<bool>};
     fmap[DPNPFuncName::DPNP_FN_COPY][eft_INT][eft_INT] = {

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 2 deletions

@@ -70,8 +70,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
         DPNP_FN_CHOLESKY_EXT
         DPNP_FN_CHOOSE
         DPNP_FN_CHOOSE_EXT
-        DPNP_FN_CONJIGUATE
-        DPNP_FN_CONJIGUATE_EXT
         DPNP_FN_COPY
         DPNP_FN_COPY_EXT
         DPNP_FN_COPYSIGN

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

Lines changed: 0 additions & 5 deletions

@@ -40,7 +40,6 @@ __all__ += [
     "dpnp_arctan2",
     "dpnp_around",
     "dpnp_ceil",
-    "dpnp_conjugate",
     "dpnp_copysign",
     "dpnp_cross",
     "dpnp_cumprod",
@@ -163,10 +162,6 @@ cpdef utils.dpnp_descriptor dpnp_ceil(utils.dpnp_descriptor x1, utils.dpnp_descr
     return call_fptr_1in_1out_strides(DPNP_FN_CEIL_EXT, x1, dtype=None, out=out, where=True, func_name='ceil')
 
 
-cpdef utils.dpnp_descriptor dpnp_conjugate(utils.dpnp_descriptor x1):
-    return call_fptr_1in_1out_strides(DPNP_FN_CONJIGUATE_EXT, x1)
-
-
 cpdef utils.dpnp_descriptor dpnp_copysign(utils.dpnp_descriptor x1_obj,
                                           utils.dpnp_descriptor x2_obj,
                                           object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 53 additions & 0 deletions

@@ -46,6 +46,7 @@
     "dpnp_bitwise_and",
     "dpnp_bitwise_or",
     "dpnp_bitwise_xor",
+    "dpnp_conj",
     "dpnp_cos",
     "dpnp_divide",
     "dpnp_equal",
@@ -367,6 +368,58 @@ def _call_cos(src, dst, sycl_queue, depends=None):
     return dpnp_array._create_from_usm_ndarray(res_usm)
 
 
+_conj_docstring = """
+conj(x, out=None, order='K')
+
+Computes conjugate for each element `x_i` for input array `x`.
+
+Args:
+    x (dpnp.ndarray):
+        Input array, expected to have numeric data type.
+    out ({None, dpnp.ndarray}, optional):
+        Output array to populate. Array must have the correct
+        shape and the expected data type.
+    order ("C","F","A","K", optional): memory layout of the new
+        output array, if parameter `out` is `None`.
+        Default: "K".
+Return:
+    dpnp.ndarray:
+        An array containing the element-wise conjugate.
+        The returned array has the same data type as `x`.
+"""
+
+
+def _call_conj(src, dst, sycl_queue, depends=None):
+    """A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
+
+    if depends is None:
+        depends = []
+
+    if vmi._mkl_conj_to_call(sycl_queue, src, dst):
+        # call pybind11 extension for conj() function from OneMKL VM
+        return vmi._conj(sycl_queue, src, dst, depends)
+    return ti._conj(src, dst, sycl_queue, depends)
+
+
+conj_func = UnaryElementwiseFunc(
+    "conj", ti._conj_result_type, _call_conj, _conj_docstring
+)
+
+
+def dpnp_conj(x, out=None, order="K"):
+    """
+    Invokes conj() function from pybind11 extension of OneMKL VM if possible.
+
+    Otherwise fully relies on dpctl.tensor implementation for conj() function.
+    """
+    # dpctl.tensor only works with usm_ndarray
+    x1_usm = dpnp.get_usm_ndarray(x)
+    out_usm = None if out is None else dpnp.get_usm_ndarray(out)
+
+    res_usm = conj_func(x1_usm, out=out_usm, order=order)
+    return dpnp_array._create_from_usm_ndarray(res_usm)
+
+
 _divide_docstring_ = """
 divide(x1, x2, out=None, order="K")
 

dpnp/dpnp_iface_mathematical.py

Lines changed: 35 additions & 7 deletions

@@ -49,7 +49,9 @@
 
 from .dpnp_algo import *
 from .dpnp_algo.dpnp_elementwise_common import (
+    check_nd_call_func,
     dpnp_add,
+    dpnp_conj,
     dpnp_divide,
     dpnp_floor_divide,
     dpnp_multiply,
@@ -387,7 +389,17 @@ def ceil(x1, out=None, **kwargs):
     return call_origin(numpy.ceil, x1, out=out, **kwargs)
 
 
-def conjugate(x1, **kwargs):
+def conjugate(
+    x,
+    /,
+    out=None,
+    *,
+    order="K",
+    where=True,
+    dtype=None,
+    subok=True,
+    **kwargs,
+):
     """
     Return the complex conjugate, element-wise.
 
@@ -396,6 +408,18 @@ def conjugate(x1, **kwargs):
 
     For full documentation refer to :obj:`numpy.conjugate`.
 
+    Returns
+    -------
+    out : dpnp.ndarray
+        The conjugate of each element of `x`.
+
+    Limitations
+    -----------
+    Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
+    Parameters `where`, `dtype` and `subok` are supported with their default values.
+    Otherwise the function will be executed sequentially on CPU.
+    Input array data types are limited by supported DPNP :ref:`Data types`.
+
     Examples
     --------
     >>> import dpnp as np
@@ -409,13 +433,17 @@ def conjugate(x1, **kwargs):
 
     """
 
-    x1_desc = dpnp.get_dpnp_descriptor(
-        x1, copy_when_strides=False, copy_when_nondefault_queue=False
+    return check_nd_call_func(
+        numpy.conjugate,
+        dpnp_conj,
+        x,
+        out=out,
+        where=where,
+        order=order,
+        dtype=dtype,
+        subok=subok,
+        **kwargs,
     )
-    if x1_desc and not kwargs:
-        return dpnp_conjugate(x1_desc).get_pyobj()
-
-    return call_origin(numpy.conjugate, x1, **kwargs)
 
 
 conj = conjugate
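With the rewritten signature, `conjugate` (and its `conj` alias) now matches the other element-wise functions routed through `check_nd_call_func`: `x` is positional-only, `out` is optional, `order` is keyword-only, and passing non-default `where`, `dtype`, or `subok` (or extra kwargs) falls back to `numpy.conjugate`. A short sketch of the public entry points; the commented values are just the mathematical conjugates, not output captured from this change:

import dpnp as np
import dpctl.tensor as dpt

x = np.array([1 + 2j, 3 - 4j])
print(np.conjugate(x))  # [1.-2.j  3.+4.j]
print(np.conj(x))       # same result via the alias

# dpctl.tensor.usm_ndarray input is accepted as well, per the Limitations above
print(np.conj(dpt.asarray([0 + 1j])))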

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 2 deletions

@@ -13,7 +13,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.
 tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.float32)]
 
 tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-ceil-data1]
-tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-conjugate-data2]
 tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-copy-data3]
 tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-cumprod-data4]
 tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-cumsum-data5]
@@ -22,7 +21,6 @@ tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-fabs-data8]
 tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-floor-data9]
 
 tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-ceil-data1]
-tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-conjugate-data2]
 tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-copy-data3]
 tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-cumprod-data4]
 tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-cumsum-data5]

0 commit comments
