Skip to content

Commit 12e6214

Browse files
committed
Separated implementation of mul, pow, rint, sin and sinh functions
1 parent cfeb9f5 commit 12e6214

File tree

13 files changed

+793
-500
lines changed

13 files changed

+793
-500
lines changed

dpnp/backend/extensions/vm/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ set(_elementwise_sources
4848
${CMAKE_CURRENT_SOURCE_DIR}/log10.cpp
4949
${CMAKE_CURRENT_SOURCE_DIR}/log1p.cpp
5050
${CMAKE_CURRENT_SOURCE_DIR}/log2.cpp
51+
${CMAKE_CURRENT_SOURCE_DIR}/mul.cpp
52+
${CMAKE_CURRENT_SOURCE_DIR}/pow.cpp
53+
${CMAKE_CURRENT_SOURCE_DIR}/rint.cpp
54+
${CMAKE_CURRENT_SOURCE_DIR}/sin.cpp
55+
${CMAKE_CURRENT_SOURCE_DIR}/sinh.cpp
5156
)
5257

5358
set(_module_src

dpnp/backend/extensions/vm/mul.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include <oneapi/mkl.hpp>
27+
#include <sycl/sycl.hpp>
28+
29+
#include "dpctl4pybind11.hpp"
30+
31+
#include "common.hpp"
32+
#include "mul.hpp"
33+
34+
// include a local copy of elementwise common header from dpctl tensor:
35+
// dpctl/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp
36+
// TODO: replace by including dpctl header once available
37+
#include "../elementwise_functions/elementwise_functions.hpp"
38+
39+
// dpctl tensor headers
40+
#include "kernels/elementwise_functions/common.hpp"
41+
#include "utils/type_dispatch.hpp"
42+
#include "utils/type_utils.hpp"
43+
44+
namespace dpnp::extensions::vm
45+
{
46+
namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common;
47+
namespace py = pybind11;
48+
namespace py_int = dpnp::extensions::py_internal;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
namespace vm_ext = dpnp::backend::ext::vm;
52+
53+
namespace impl
54+
{
55+
// OneMKL namespace with VM functions
56+
namespace mkl_vm = oneapi::mkl::vm;
57+
58+
/**
59+
* @brief A factory to define pairs of supported types for which
60+
* MKL VM library provides support in oneapi::mkl::vm::mul<T> function.
61+
*
62+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
63+
*/
64+
template <typename T1, typename T2>
65+
struct OutputType
66+
{
67+
using value_type = typename std::disjunction<
68+
td_ns::BinaryTypeMapResultEntry<T1,
69+
std::complex<double>,
70+
T2,
71+
std::complex<double>,
72+
std::complex<double>>,
73+
td_ns::BinaryTypeMapResultEntry<T1,
74+
std::complex<float>,
75+
T2,
76+
std::complex<float>,
77+
std::complex<float>>,
78+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
79+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
80+
td_ns::DefaultResultEntry<void>>::result_type;
81+
};
82+
83+
template <typename T1, typename T2>
84+
static sycl::event mul_contig_impl(sycl::queue &exec_q,
85+
std::size_t in_n,
86+
const char *in_a,
87+
ssize_t a_offset,
88+
const char *in_b,
89+
ssize_t b_offset,
90+
char *out_y,
91+
ssize_t out_offset,
92+
const std::vector<sycl::event> &depends)
93+
{
94+
tu_ns::validate_type_for_device<T1>(exec_q);
95+
tu_ns::validate_type_for_device<T2>(exec_q);
96+
97+
if ((a_offset != 0) || (b_offset != 0) || (out_offset != 0)) {
98+
throw std::runtime_error("Arrays offsets have to be equals to 0");
99+
}
100+
101+
std::int64_t n = static_cast<std::int64_t>(in_n);
102+
const T1 *a = reinterpret_cast<const T1 *>(in_a);
103+
const T2 *b = reinterpret_cast<const T2 *>(in_b);
104+
105+
using resTy = typename OutputType<T1, T2>::value_type;
106+
resTy *y = reinterpret_cast<resTy *>(out_y);
107+
108+
return mkl_vm::mul(exec_q,
109+
n, // number of elements to be calculated
110+
a, // pointer `a` containing 1st input vector of size n
111+
b, // pointer `b` containing 2nd input vector of size n
112+
y, // pointer `y` to the output vector of size n
113+
depends);
114+
}
115+
116+
using ew_cmn_ns::binary_contig_impl_fn_ptr_t;
117+
using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t;
118+
using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t;
119+
using ew_cmn_ns::binary_strided_impl_fn_ptr_t;
120+
121+
static int output_typeid_vector[td_ns::num_types][td_ns::num_types];
122+
static binary_contig_impl_fn_ptr_t contig_dispatch_vector[td_ns::num_types]
123+
[td_ns::num_types];
124+
125+
MACRO_POPULATE_DISPATCH_TABLES(mul);
126+
} // namespace impl
127+
128+
void init_mul(py::module_ m)
129+
{
130+
using arrayT = dpctl::tensor::usm_ndarray;
131+
using event_vecT = std::vector<sycl::event>;
132+
133+
impl::populate_dispatch_tables();
134+
using impl::contig_dispatch_vector;
135+
using impl::output_typeid_vector;
136+
137+
auto mul_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
138+
arrayT dst, const event_vecT &depends = {}) {
139+
return py_int::py_binary_ufunc(
140+
src1, src2, dst, exec_q, depends, output_typeid_vector,
141+
contig_dispatch_vector,
142+
// no support of strided implementation in OneMKL
143+
td_ns::NullPtrTable<impl::binary_strided_impl_fn_ptr_t>{},
144+
// no support of C-contig row with broadcasting in OneMKL
145+
td_ns::NullPtrTable<
146+
impl::
147+
binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{},
148+
td_ns::NullPtrTable<
149+
impl::
150+
binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{});
151+
};
152+
m.def("_mul", mul_pyapi,
153+
"Call `mul` function from OneMKL VM library to performs element "
154+
"by element multiplication of vector `src1` by vector `src2` "
155+
"to resulting vector `dst`",
156+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
157+
py::arg("dst"), py::arg("depends") = py::list());
158+
159+
auto mul_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
160+
arrayT src2, arrayT dst) {
161+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
162+
output_typeid_vector,
163+
contig_dispatch_vector);
164+
};
165+
m.def("_mkl_mul_to_call", mul_need_to_call_pyapi,
166+
"Check input arguments to answer if `mul` function from "
167+
"OneMKL VM library can be used",
168+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
169+
py::arg("dst"));
170+
}
171+
} // namespace dpnp::extensions::vm

dpnp/backend/extensions/vm/mul.hpp

Lines changed: 5 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -25,58 +25,11 @@
2525

2626
#pragma once
2727

28-
#include <CL/sycl.hpp>
28+
#include <pybind11/pybind11.h>
2929

30-
#include "common.hpp"
31-
#include "types_matrix.hpp"
30+
namespace py = pybind11;
3231

33-
namespace dpnp
32+
namespace dpnp::extensions::vm
3433
{
35-
namespace backend
36-
{
37-
namespace ext
38-
{
39-
namespace vm
40-
{
41-
template <typename T>
42-
sycl::event mul_contig_impl(sycl::queue exec_q,
43-
const std::int64_t n,
44-
const char *in_a,
45-
const char *in_b,
46-
char *out_y,
47-
const std::vector<sycl::event> &depends)
48-
{
49-
type_utils::validate_type_for_device<T>(exec_q);
50-
51-
const T *a = reinterpret_cast<const T *>(in_a);
52-
const T *b = reinterpret_cast<const T *>(in_b);
53-
using resTy = typename types::MulOutputType<T>::value_type;
54-
resTy *y = reinterpret_cast<resTy *>(out_y);
55-
56-
return mkl_vm::mul(exec_q,
57-
n, // number of elements to be calculated
58-
a, // pointer `a` containing 1st input vector of size n
59-
b, // pointer `b` containing 2nd input vector of size n
60-
y, // pointer `y` to the output vector of size n
61-
depends);
62-
}
63-
64-
template <typename fnT, typename T>
65-
struct MulContigFactory
66-
{
67-
fnT get()
68-
{
69-
if constexpr (std::is_same_v<
70-
typename types::MulOutputType<T>::value_type, void>)
71-
{
72-
return nullptr;
73-
}
74-
else {
75-
return mul_contig_impl<T>;
76-
}
77-
}
78-
};
79-
} // namespace vm
80-
} // namespace ext
81-
} // namespace backend
82-
} // namespace dpnp
34+
void init_mul(py::module_ m);
35+
} // namespace dpnp::extensions::vm

0 commit comments

Comments
 (0)