Skip to content

Commit 0eb7a91

Browse files
committed
impl elementwise exp and sin
1 parent 82c3223 commit 0eb7a91

File tree

7 files changed

+703
-38
lines changed

7 files changed

+703
-38
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@
9797
cos,
9898
divide,
9999
equal,
100+
exp,
100101
isfinite,
101102
isinf,
102103
isnan,
103104
multiply,
105+
sin,
104106
sqrt,
105107
subtract,
106108
)
@@ -183,10 +185,12 @@
183185
"abs",
184186
"add",
185187
"cos",
188+
"sin",
186189
"isinf",
187190
"isnan",
188191
"isfinite",
189192
"sqrt",
193+
"exp",
190194
"divide",
191195
"multiply",
192196
"subtract",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,13 @@
150150
)
151151

152152
# U13: ==== EXP (x)
153-
# FIXME: implement U13
153+
_exp_docstring = """
154+
exp(x, order='K')
155+
156+
Computes exponential for each element `x_i` for input array `x`.
157+
"""
158+
159+
exp = UnaryElementwiseFunc("exp", ti._exp_result_type, ti._exp, _exp_docstring)
154160

155161
# U14: ==== EXPM1 (x)
156162
# FIXME: implement U14
@@ -282,7 +288,13 @@
282288
# FIXME: implement U29
283289

284290
# U30: ==== SIN (x)
285-
# FIXME: implement U30
291+
_sin_docstring = """
292+
sin(x, order='K')
293+
294+
Computes sin for each element `x_i` for input array `x`.
295+
"""
296+
297+
sin = UnaryElementwiseFunc("sin", ti._sin_result_type, ti._sin, _sin_docstring)
286298

287299
# U31: ==== SINH (x)
288300
# FIXME: implement U31
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
//=== exp.hpp - Unary function EXP ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of EXP(x) function.
23+
//===---------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <cmath>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "kernels/elementwise_functions/common.hpp"
33+
34+
#include "utils/offset_utils.hpp"
35+
#include "utils/type_dispatch.hpp"
36+
#include "utils/type_utils.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace exp
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
51+
using dpctl::tensor::type_utils::is_complex;
52+
53+
template <typename argT, typename resT> struct ExpFunctor
54+
{
55+
// is function constant for given argT
56+
using is_constant = typename std::false_type;
57+
// constant value, if constant
58+
// constexpr resT constant_value = resT{};
59+
// is function defined for sycl::vec
60+
using supports_vec = typename std::false_type;
61+
// do both argTy and resTy support sugroup store/load operation
62+
using supports_sg_loadstore = typename std::negation<
63+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
64+
65+
resT operator()(const argT &in)
66+
{
67+
return std::exp(in);
68+
}
69+
};
70+
71+
template <typename argTy,
72+
typename resTy = argTy,
73+
unsigned int vec_sz = 4,
74+
unsigned int n_vecs = 2>
75+
using ExpContigFunctor = elementwise_common::
76+
UnaryContigFunctor<argTy, resTy, ExpFunctor<argTy, resTy>, vec_sz, n_vecs>;
77+
78+
template <typename argTy, typename resTy, typename IndexerT>
79+
using ExpStridedFunctor = elementwise_common::
80+
UnaryStridedFunctor<argTy, resTy, IndexerT, ExpFunctor<argTy, resTy>>;
81+
82+
template <typename T> struct ExpOutputType
83+
{
84+
using value_type = typename std::disjunction< // disjunction is C++17
85+
// feature, supported by DPC++
86+
td_ns::TypeMapResultEntry<T, sycl::half>,
87+
td_ns::TypeMapResultEntry<T, float>,
88+
td_ns::TypeMapResultEntry<T, double>,
89+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
90+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
91+
td_ns::DefaultResultEntry<void>>::result_type;
92+
};
93+
94+
/*! @brief Declaration-only class template; exp_contig_impl passes it to
 *         elementwise_common::unary_contig_impl as a template argument. */
template <typename ArgT,
          typename ResT,
          unsigned int VecSz,
          unsigned int NVecs>
class exp_contig_kernel;
96+
97+
template <typename argTy>
98+
sycl::event exp_contig_impl(sycl::queue exec_q,
99+
size_t nelems,
100+
const char *arg_p,
101+
char *res_p,
102+
const std::vector<sycl::event> &depends = {})
103+
{
104+
return elementwise_common::unary_contig_impl<
105+
argTy, ExpOutputType, ExpContigFunctor, exp_contig_kernel>(
106+
exec_q, nelems, arg_p, res_p, depends);
107+
}
108+
109+
template <typename fnT, typename T> struct ExpContigFactory
110+
{
111+
fnT get()
112+
{
113+
if constexpr (std::is_same_v<typename ExpOutputType<T>::value_type,
114+
void>) {
115+
fnT fn = nullptr;
116+
return fn;
117+
}
118+
else {
119+
fnT fn = exp_contig_impl<T>;
120+
return fn;
121+
}
122+
}
123+
};
124+
125+
template <typename fnT, typename T> struct ExpTypeMapFactory
126+
{
127+
/*! @brief get typeid for output type of std::exp(T x) */
128+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
129+
{
130+
using rT = typename ExpOutputType<T>::value_type;
131+
return td_ns::GetTypeid<rT>{}.get();
132+
}
133+
};
134+
135+
/*! @brief Declaration-only class template; exp_strided_impl passes it to
 *         elementwise_common::unary_strided_impl as a template argument. */
template <typename ArgT, typename ResT, typename IndexerT>
class exp_strided_kernel;
136+
137+
template <typename argTy>
138+
sycl::event exp_strided_impl(sycl::queue exec_q,
139+
size_t nelems,
140+
int nd,
141+
const py::ssize_t *shape_and_strides,
142+
const char *arg_p,
143+
py::ssize_t arg_offset,
144+
char *res_p,
145+
py::ssize_t res_offset,
146+
const std::vector<sycl::event> &depends,
147+
const std::vector<sycl::event> &additional_depends)
148+
{
149+
return elementwise_common::unary_strided_impl<
150+
argTy, ExpOutputType, ExpStridedFunctor, exp_strided_kernel>(
151+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
152+
res_offset, depends, additional_depends);
153+
}
154+
155+
template <typename fnT, typename T> struct ExpStridedFactory
156+
{
157+
fnT get()
158+
{
159+
if constexpr (std::is_same_v<typename ExpOutputType<T>::value_type,
160+
void>) {
161+
fnT fn = nullptr;
162+
return fn;
163+
}
164+
else {
165+
fnT fn = exp_strided_impl<T>;
166+
return fn;
167+
}
168+
}
169+
};
170+
171+
} // namespace exp
172+
} // namespace kernels
173+
} // namespace tensor
174+
} // namespace dpctl

0 commit comments

Comments
 (0)