Skip to content

Commit 3a063f2

Browse files
authored
Merge pull request #1218 from IntelPython/elementwise-expm1-log1p-log
Implementation of expm1, log, and log1p
2 parents 521867b + 190e5d3 commit 3a063f2

File tree

9 files changed

+1325
-9
lines changed

9 files changed

+1325
-9
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,12 @@
9797
cos,
9898
divide,
9999
equal,
100+
expm1,
100101
isfinite,
101102
isinf,
102103
isnan,
104+
log,
105+
log1p,
103106
multiply,
104107
sqrt,
105108
subtract,
@@ -184,9 +187,12 @@
184187
"abs",
185188
"add",
186189
"cos",
190+
"expm1",
187191
"isinf",
188192
"isnan",
189193
"isfinite",
194+
"log",
195+
"log1p",
190196
"sqrt",
191197
"divide",
192198
"multiply",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,26 @@
153153
# FIXME: implement U13
154154

155155
# U14: ==== EXPM1 (x)
156-
# FIXME: implement U14
156+
# Docstring attached to the `expm1` element-wise function object below.
_expm1_docstring = """
expm1(x, out=None, order='K')

Computes an approximation of exp(x)-1 element-wise.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out (usm_ndarray, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type. Default: `None`.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise exp(x)-1 values.
"""

expm1 = UnaryElementwiseFunc(
    "expm1", ti._expm1_result_type, ti._expm1, _expm1_docstring
)
157176

158177
# U15: ==== FLOOR (x)
159178
# FIXME: implement U15
@@ -210,10 +229,46 @@
210229
# FIXME: implement B14
211230

212231
# U20: ==== LOG (x)
213-
# FIXME: implement U20
232+
# Docstring attached to the `log` element-wise function object below.
_log_docstring = """
log(x, out=None, order='K')

Computes the natural logarithm element-wise.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out (usm_ndarray, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type. Default: `None`.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise natural logarithm values.
"""

log = UnaryElementwiseFunc("log", ti._log_result_type, ti._log, _log_docstring)
214250

215251
# U21: ==== LOG1P (x)
216-
# FIXME: implement U21
252+
# Docstring attached to the `log1p` element-wise function object below.
_log1p_docstring = """
log1p(x, out=None, order='K')

Computes an approximation of log(1+x) element-wise.

Args:
    x (usm_ndarray):
        Input array, expected to have numeric data type.
    out (usm_ndarray, optional):
        Output array to populate. Array must have the correct
        shape and the expected data type. Default: `None`.
    order ("C","F","A","K", optional): memory layout of the new
        output array, if parameter `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the element-wise log(1+x) values.
"""

log1p = UnaryElementwiseFunc(
    "log1p", ti._log1p_result_type, ti._log1p, _log1p_docstring
)
217272

218273
# U22: ==== LOG2 (x)
219274
# FIXME: implement U22
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
//=== expm1.hpp - Unary function EXPM1 ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of EXPM1(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "kernels/elementwise_functions/common.hpp"
34+
35+
#include "utils/offset_utils.hpp"
36+
#include "utils/type_dispatch.hpp"
37+
#include "utils/type_utils.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace expm1
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
52+
using dpctl::tensor::type_utils::is_complex;
53+
54+
template <typename argT, typename resT> struct Expm1Functor
55+
{
56+
57+
// is function constant for given argT
58+
using is_constant = typename std::false_type;
59+
// constant value, if constant
60+
// constexpr resT constant_value = resT{};
61+
// is function defined for sycl::vec
62+
using supports_vec = typename std::false_type;
63+
// do both argTy and resTy support sugroup store/load operation
64+
using supports_sg_loadstore = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
66+
67+
resT operator()(const argT &in)
68+
{
69+
if constexpr (is_complex<argT>::value) {
70+
using realT = typename argT::value_type;
71+
// expm1(x + I*y) = expm1(x)*cos(y) - 2*sin(y / 2)^2 +
72+
// I*exp(x)*sin(y)
73+
const realT x = std::real(in);
74+
const realT y = std::imag(in);
75+
76+
realT cosY_val;
77+
const realT sinY_val = sycl::sincos(y, &cosY_val);
78+
const realT sinhalfY_val = std::sin(y / 2);
79+
80+
const realT res_re =
81+
std::expm1(x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
82+
const realT res_im = std::exp(x) * sinY_val;
83+
return resT{res_re, res_im};
84+
}
85+
else {
86+
static_assert(std::is_floating_point_v<argT> ||
87+
std::is_same_v<argT, sycl::half>);
88+
return std::expm1(in);
89+
}
90+
}
91+
};
92+
93+
template <typename argTy,
94+
typename resTy = argTy,
95+
unsigned int vec_sz = 4,
96+
unsigned int n_vecs = 2>
97+
using Expm1ContigFunctor =
98+
elementwise_common::UnaryContigFunctor<argTy,
99+
resTy,
100+
Expm1Functor<argTy, resTy>,
101+
vec_sz,
102+
n_vecs>;
103+
104+
template <typename argTy, typename resTy, typename IndexerT>
105+
using Expm1StridedFunctor = elementwise_common::
106+
UnaryStridedFunctor<argTy, resTy, IndexerT, Expm1Functor<argTy, resTy>>;
107+
108+
template <typename T> struct Expm1OutputType
109+
{
110+
using value_type = typename std::disjunction< // disjunction is C++17
111+
// feature, supported by DPC++
112+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
113+
td_ns::TypeMapResultEntry<T, float, float>,
114+
td_ns::TypeMapResultEntry<T, double, double>,
115+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
116+
td_ns::
117+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
118+
td_ns::DefaultResultEntry<void>>::result_type;
119+
};
120+
121+
/*! @brief Pointer to a function with the signature of expm1_contig_impl:
 *  (queue, element count, input bytes, output bytes, dependencies). */
using expm1_contig_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,
                    size_t,
                    const char *,
                    char *,
                    const std::vector<sycl::event> &);
127+
128+
// Kernel name type for the contiguous expm1 kernel; declared only, never
// defined, and used solely as the template argument of parallel_for.
template <typename ArgT,
          typename ResT,
          unsigned int vec_width,
          unsigned int vec_count>
class expm1_contig_kernel;
130+
131+
template <typename argTy>
132+
sycl::event expm1_contig_impl(sycl::queue exec_q,
133+
size_t nelems,
134+
const char *arg_p,
135+
char *res_p,
136+
const std::vector<sycl::event> &depends = {})
137+
{
138+
sycl::event expm1_ev = exec_q.submit([&](sycl::handler &cgh) {
139+
cgh.depends_on(depends);
140+
constexpr size_t lws = 64;
141+
constexpr unsigned int vec_sz = 4;
142+
constexpr unsigned int n_vecs = 2;
143+
static_assert(lws % vec_sz == 0);
144+
auto gws_range = sycl::range<1>(
145+
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
146+
lws);
147+
auto lws_range = sycl::range<1>(lws);
148+
149+
using resTy = typename Expm1OutputType<argTy>::value_type;
150+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
151+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
152+
153+
cgh.parallel_for<
154+
class expm1_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
155+
sycl::nd_range<1>(gws_range, lws_range),
156+
Expm1ContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
157+
nelems));
158+
});
159+
return expm1_ev;
160+
}
161+
162+
template <typename fnT, typename T> struct Expm1ContigFactory
163+
{
164+
fnT get()
165+
{
166+
if constexpr (std::is_same_v<typename Expm1OutputType<T>::value_type,
167+
void>) {
168+
fnT fn = nullptr;
169+
return fn;
170+
}
171+
else {
172+
fnT fn = expm1_contig_impl<T>;
173+
return fn;
174+
}
175+
}
176+
};
177+
178+
template <typename fnT, typename T> struct Expm1TypeMapFactory
179+
{
180+
/*! @brief get typeid for output type of std::expm1(T x) */
181+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
182+
{
183+
using rT = typename Expm1OutputType<T>::value_type;
184+
;
185+
return td_ns::GetTypeid<rT>{}.get();
186+
}
187+
};
188+
189+
// Kernel name type for the strided expm1 kernel; declared only, never
// defined, and used solely as the template argument of parallel_for.
template <typename ArgT, typename ResT, typename IdxT>
class expm1_strided_kernel;
190+
191+
/*! @brief Pointer to a function with the signature of expm1_strided_impl:
 *  (queue, element count, number of dimensions, shape-and-strides array,
 *  input bytes, input offset, output bytes, output offset, and two event
 *  lists). */
using expm1_strided_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,
                    size_t,
                    int,
                    const py::ssize_t *,
                    const char *,
                    py::ssize_t,
                    char *,
                    py::ssize_t,
                    const std::vector<sycl::event> &,
                    const std::vector<sycl::event> &);
202+
203+
template <typename argTy>
204+
sycl::event
205+
expm1_strided_impl(sycl::queue exec_q,
206+
size_t nelems,
207+
int nd,
208+
const py::ssize_t *shape_and_strides,
209+
const char *arg_p,
210+
py::ssize_t arg_offset,
211+
char *res_p,
212+
py::ssize_t res_offset,
213+
const std::vector<sycl::event> &depends,
214+
const std::vector<sycl::event> &additional_depends)
215+
{
216+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
217+
cgh.depends_on(depends);
218+
cgh.depends_on(additional_depends);
219+
220+
using resTy = typename Expm1OutputType<argTy>::value_type;
221+
using IndexerT =
222+
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
223+
224+
IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
225+
226+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
227+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
228+
229+
sycl::range<1> gRange{nelems};
230+
231+
cgh.parallel_for<expm1_strided_kernel<argTy, resTy, IndexerT>>(
232+
gRange, Expm1StridedFunctor<argTy, resTy, IndexerT>(
233+
arg_tp, res_tp, arg_res_indexer));
234+
});
235+
return comp_ev;
236+
}
237+
238+
template <typename fnT, typename T> struct Expm1StridedFactory
239+
{
240+
fnT get()
241+
{
242+
if constexpr (std::is_same_v<typename Expm1OutputType<T>::value_type,
243+
void>) {
244+
fnT fn = nullptr;
245+
return fn;
246+
}
247+
else {
248+
fnT fn = expm1_strided_impl<T>;
249+
return fn;
250+
}
251+
}
252+
};
253+
254+
} // namespace expm1
255+
} // namespace kernels
256+
} // namespace tensor
257+
} // namespace dpctl

0 commit comments

Comments
 (0)