Skip to content

Commit d553611

Browse files
authored
Preparation to reuse future common dpctl f/w in functions from vm extension (#1868)
* Preparation to reuse common dpctl f/w for VM functions * PoC to decouple abs implementation to separate source file * Reuse typedef for function poiter from dpctl.tensor * Define populating vectors by a separate macro * Move implementation of utility functions from headers to source to resolve link issues * Separated implementation of acos function * Separated implementation of acosh function * Use function to simplify strides from dpctl tensor headers * PoC to decouple add implementation to separate source file * Separated implementation of asin function * Separated implementation of asinh function * Separated implementation of atan, atan2, atanh functions * Resolve issue with calling MKL function for undefined types * Separated implementation of cbrt, ceil, conj, cos and cosh functions * Separated implementation of div, exp, exp2, expm1, floor and hypot functions * Separated implementation of ln, log1p, log2 and log10 functions * Separated implementation of mul, pow, rint, sin and sinh functions * Separated implementation of sqr, sqrt, sub, tan, tanh and trunc functions * Removed unused header with types matrix * Remove unused functions * Use passing by reference in unary and binary funcs
1 parent 896209a commit d553611

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+6629
-3681
lines changed

dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp

Lines changed: 824 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include "dpctl4pybind11.hpp"
27+
28+
#include <pybind11/numpy.h>
29+
#include <pybind11/pybind11.h>
30+
#include <sycl/sycl.hpp>
31+
32+
#include "elementwise_functions_type_utils.hpp"
33+
34+
// dpctl tensor headers
35+
#include "utils/type_dispatch.hpp"
36+
37+
namespace py = pybind11;
38+
namespace td_ns = dpctl::tensor::type_dispatch;
39+
40+
namespace dpnp::extensions::py_internal::type_utils
41+
{
42+
py::dtype _dtype_from_typenum(td_ns::typenum_t dst_typenum_t)
43+
{
44+
switch (dst_typenum_t) {
45+
case td_ns::typenum_t::BOOL:
46+
return py::dtype("?");
47+
case td_ns::typenum_t::INT8:
48+
return py::dtype("i1");
49+
case td_ns::typenum_t::UINT8:
50+
return py::dtype("u1");
51+
case td_ns::typenum_t::INT16:
52+
return py::dtype("i2");
53+
case td_ns::typenum_t::UINT16:
54+
return py::dtype("u2");
55+
case td_ns::typenum_t::INT32:
56+
return py::dtype("i4");
57+
case td_ns::typenum_t::UINT32:
58+
return py::dtype("u4");
59+
case td_ns::typenum_t::INT64:
60+
return py::dtype("i8");
61+
case td_ns::typenum_t::UINT64:
62+
return py::dtype("u8");
63+
case td_ns::typenum_t::HALF:
64+
return py::dtype("f2");
65+
case td_ns::typenum_t::FLOAT:
66+
return py::dtype("f4");
67+
case td_ns::typenum_t::DOUBLE:
68+
return py::dtype("f8");
69+
case td_ns::typenum_t::CFLOAT:
70+
return py::dtype("c8");
71+
case td_ns::typenum_t::CDOUBLE:
72+
return py::dtype("c16");
73+
default:
74+
throw py::value_error("Unrecognized dst_typeid");
75+
}
76+
}
77+
78+
int _result_typeid(int arg_typeid, const int *fn_output_id)
79+
{
80+
if (arg_typeid < 0 || arg_typeid >= td_ns::num_types) {
81+
throw py::value_error("Input typeid " + std::to_string(arg_typeid) +
82+
" is outside of expected bounds.");
83+
}
84+
85+
return fn_output_id[arg_typeid];
86+
}
87+
} // namespace dpnp::extensions::py_internal::type_utils
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include "dpctl4pybind11.hpp"
29+
#include <pybind11/numpy.h>
30+
#include <pybind11/pybind11.h>
31+
#include <pybind11/stl.h>
32+
33+
// dpctl tensor headers
34+
#include "utils/type_dispatch.hpp"
35+
36+
namespace py = pybind11;
37+
namespace td_ns = dpctl::tensor::type_dispatch;
38+
39+
namespace dpnp::extensions::py_internal::type_utils
40+
{
41+
/*! @brief Produce dtype from a type number */
42+
extern py::dtype _dtype_from_typenum(td_ns::typenum_t);
43+
44+
/*! @brief Lookup typeid of the result from typeid of
45+
* argument and the mapping table */
46+
extern int _result_typeid(int, const int *);
47+
} // namespace dpnp::extensions::py_internal::type_utils
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2024, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#include "dpctl4pybind11.hpp"
27+
28+
#include <pybind11/pybind11.h>
29+
#include <vector>
30+
31+
#include "simplify_iteration_space.hpp"
32+
33+
// dpctl tensor headers
34+
#include "utils/strided_iters.hpp"
35+
36+
namespace dpnp::extensions::py_internal
37+
{
38+
namespace py = pybind11;
39+
namespace st_ns = dpctl::tensor::strides;
40+
41+
void simplify_iteration_space(int &nd,
42+
const py::ssize_t *const &shape,
43+
std::vector<py::ssize_t> const &src_strides,
44+
std::vector<py::ssize_t> const &dst_strides,
45+
// output
46+
std::vector<py::ssize_t> &simplified_shape,
47+
std::vector<py::ssize_t> &simplified_src_strides,
48+
std::vector<py::ssize_t> &simplified_dst_strides,
49+
py::ssize_t &src_offset,
50+
py::ssize_t &dst_offset)
51+
{
52+
if (nd > 1) {
53+
// Simplify iteration space to reduce dimensionality
54+
// and improve access pattern
55+
simplified_shape.reserve(nd);
56+
simplified_shape.insert(std::begin(simplified_shape), shape,
57+
shape + nd);
58+
assert(simplified_shape.size() == static_cast<size_t>(nd));
59+
60+
simplified_src_strides.reserve(nd);
61+
simplified_src_strides.insert(std::end(simplified_src_strides),
62+
std::begin(src_strides),
63+
std::end(src_strides));
64+
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
65+
66+
simplified_dst_strides.reserve(nd);
67+
simplified_dst_strides.insert(std::end(simplified_dst_strides),
68+
std::begin(dst_strides),
69+
std::end(dst_strides));
70+
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
71+
72+
int contracted_nd = st_ns::simplify_iteration_two_strides(
73+
nd, simplified_shape.data(), simplified_src_strides.data(),
74+
simplified_dst_strides.data(),
75+
src_offset, // modified by reference
76+
dst_offset // modified by reference
77+
);
78+
simplified_shape.resize(contracted_nd);
79+
simplified_src_strides.resize(contracted_nd);
80+
simplified_dst_strides.resize(contracted_nd);
81+
82+
nd = contracted_nd;
83+
}
84+
else if (nd == 1) {
85+
src_offset = 0;
86+
dst_offset = 0;
87+
// Populate vectors
88+
simplified_shape.reserve(nd);
89+
simplified_shape.push_back(shape[0]);
90+
assert(simplified_shape.size() == static_cast<size_t>(nd));
91+
92+
simplified_src_strides.reserve(nd);
93+
simplified_dst_strides.reserve(nd);
94+
95+
if (src_strides[0] < 0 && dst_strides[0] < 0) {
96+
simplified_src_strides.push_back(-src_strides[0]);
97+
simplified_dst_strides.push_back(-dst_strides[0]);
98+
if (shape[0] > 1) {
99+
src_offset += (shape[0] - 1) * src_strides[0];
100+
dst_offset += (shape[0] - 1) * dst_strides[0];
101+
}
102+
}
103+
else {
104+
simplified_src_strides.push_back(src_strides[0]);
105+
simplified_dst_strides.push_back(dst_strides[0]);
106+
}
107+
108+
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
109+
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
110+
}
111+
}
112+
113+
void simplify_iteration_space_3(
114+
int &nd,
115+
const py::ssize_t *const &shape,
116+
// src1
117+
std::vector<py::ssize_t> const &src1_strides,
118+
// src2
119+
std::vector<py::ssize_t> const &src2_strides,
120+
// dst
121+
std::vector<py::ssize_t> const &dst_strides,
122+
// output
123+
std::vector<py::ssize_t> &simplified_shape,
124+
std::vector<py::ssize_t> &simplified_src1_strides,
125+
std::vector<py::ssize_t> &simplified_src2_strides,
126+
std::vector<py::ssize_t> &simplified_dst_strides,
127+
py::ssize_t &src1_offset,
128+
py::ssize_t &src2_offset,
129+
py::ssize_t &dst_offset)
130+
{
131+
if (nd > 1) {
132+
// Simplify iteration space to reduce dimensionality
133+
// and improve access pattern
134+
simplified_shape.reserve(nd);
135+
simplified_shape.insert(std::end(simplified_shape), shape, shape + nd);
136+
assert(simplified_shape.size() == static_cast<size_t>(nd));
137+
138+
simplified_src1_strides.reserve(nd);
139+
simplified_src1_strides.insert(std::end(simplified_src1_strides),
140+
std::begin(src1_strides),
141+
std::end(src1_strides));
142+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
143+
144+
simplified_src2_strides.reserve(nd);
145+
simplified_src2_strides.insert(std::end(simplified_src2_strides),
146+
std::begin(src2_strides),
147+
std::end(src2_strides));
148+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
149+
150+
simplified_dst_strides.reserve(nd);
151+
simplified_dst_strides.insert(std::end(simplified_dst_strides),
152+
std::begin(dst_strides),
153+
std::end(dst_strides));
154+
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
155+
156+
int contracted_nd = st_ns::simplify_iteration_three_strides(
157+
nd, simplified_shape.data(), simplified_src1_strides.data(),
158+
simplified_src2_strides.data(), simplified_dst_strides.data(),
159+
src1_offset, // modified by reference
160+
src2_offset, // modified by reference
161+
dst_offset // modified by reference
162+
);
163+
simplified_shape.resize(contracted_nd);
164+
simplified_src1_strides.resize(contracted_nd);
165+
simplified_src2_strides.resize(contracted_nd);
166+
simplified_dst_strides.resize(contracted_nd);
167+
168+
nd = contracted_nd;
169+
}
170+
else if (nd == 1) {
171+
src1_offset = 0;
172+
src2_offset = 0;
173+
dst_offset = 0;
174+
// Populate vectors
175+
simplified_shape.reserve(nd);
176+
simplified_shape.push_back(shape[0]);
177+
assert(simplified_shape.size() == static_cast<size_t>(nd));
178+
179+
simplified_src1_strides.reserve(nd);
180+
simplified_src2_strides.reserve(nd);
181+
simplified_dst_strides.reserve(nd);
182+
183+
if ((src1_strides[0] < 0) && (src2_strides[0] < 0) &&
184+
(dst_strides[0] < 0)) {
185+
simplified_src1_strides.push_back(-src1_strides[0]);
186+
simplified_src2_strides.push_back(-src2_strides[0]);
187+
simplified_dst_strides.push_back(-dst_strides[0]);
188+
if (shape[0] > 1) {
189+
src1_offset += src1_strides[0] * (shape[0] - 1);
190+
src2_offset += src2_strides[0] * (shape[0] - 1);
191+
dst_offset += dst_strides[0] * (shape[0] - 1);
192+
}
193+
}
194+
else {
195+
simplified_src1_strides.push_back(src1_strides[0]);
196+
simplified_src2_strides.push_back(src2_strides[0]);
197+
simplified_dst_strides.push_back(dst_strides[0]);
198+
}
199+
200+
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
201+
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
202+
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
203+
}
204+
}
205+
} // namespace dpnp::extensions::py_internal

0 commit comments

Comments
 (0)