Skip to content

Commit 40113bc

Browse files
Merge pull request #1205 from IntelPython/feature/elementwise-functions-sqrt
Implements dpctl.tensor.sqrt
2 parents d699a5f + 5d37d74 commit 40113bc

File tree

5 files changed

+410
-2
lines changed

5 files changed

+410
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
from dpctl.tensor._utility_functions import all, any
9292

9393
from ._constants import e, inf, nan, newaxis, pi
94-
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan
94+
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt
9595

9696
__all__ = [
9797
"Device",
@@ -174,4 +174,5 @@
174174
"isinf",
175175
"isnan",
176176
"isfinite",
177+
"sqrt",
177178
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,13 @@
7171
isinf = UnaryElementwiseFunc(
7272
"isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
7373
)
74+
75+
# SQRT
76+
77+
_sqrt_docstring_ = """
78+
Computes sqrt for each element `x_i` for input array `x`.
79+
"""
80+
81+
sqrt = UnaryElementwiseFunc(
82+
"sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_
83+
)
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#pragma once
2+
#include <CL/sycl.hpp>
3+
#include <cmath>
4+
#include <cstddef>
5+
#include <cstdint>
6+
#include <type_traits>
7+
8+
#include "kernels/elementwise_functions/common.hpp"
9+
10+
#include "utils/offset_utils.hpp"
11+
#include "utils/type_dispatch.hpp"
12+
#include "utils/type_utils.hpp"
13+
#include <pybind11/pybind11.h>
14+
15+
namespace dpctl
16+
{
17+
namespace tensor
18+
{
19+
namespace kernels
20+
{
21+
namespace sqrt
22+
{
23+
24+
namespace py = pybind11;
25+
namespace td_ns = dpctl::tensor::type_dispatch;
26+
27+
using dpctl::tensor::type_utils::is_complex;
28+
29+
template <typename argT, typename resT> struct SqrtFunctor
30+
{
31+
32+
// is function constant for given argT
33+
using is_constant = typename std::false_type;
34+
// constant value, if constant
35+
// constexpr resT constant_value = resT{};
36+
// is function defined for sycl::vec
37+
using supports_vec = typename std::false_type;
38+
// do both argTy and resTy support sugroup store/load operation
39+
using supports_sg_loadstore = typename std::negation<
40+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
41+
42+
resT operator()(const argT &in)
43+
{
44+
return std::sqrt(in);
45+
}
46+
};
47+
48+
template <typename argTy,
49+
typename resTy = argTy,
50+
unsigned int vec_sz = 4,
51+
unsigned int n_vecs = 2>
52+
using SqrtContigFunctor = elementwise_common::
53+
UnaryContigFunctor<argTy, resTy, SqrtFunctor<argTy, resTy>, vec_sz, n_vecs>;
54+
55+
template <typename argTy, typename resTy, typename IndexerT>
56+
using SqrtStridedFunctor = elementwise_common::
57+
UnaryStridedFunctor<argTy, resTy, IndexerT, SqrtFunctor<argTy, resTy>>;
58+
59+
template <typename T> struct SqrtOutputType
60+
{
61+
using value_type = typename std::disjunction< // disjunction is C++17
62+
// feature, supported by DPC++
63+
td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
64+
td_ns::TypeMapEntry<T, float, float>,
65+
td_ns::TypeMapEntry<T, double, double>,
66+
td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
67+
td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
68+
td_ns::DefaultEntry<void>>::result_type;
69+
};
70+
71+
typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)(
72+
sycl::queue,
73+
size_t,
74+
const char *,
75+
char *,
76+
const std::vector<sycl::event> &);
77+
78+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
79+
class sqrt_contig_kernel;
80+
81+
template <typename argTy>
82+
sycl::event sqrt_contig_impl(sycl::queue exec_q,
83+
size_t nelems,
84+
const char *arg_p,
85+
char *res_p,
86+
const std::vector<sycl::event> &depends = {})
87+
{
88+
sycl::event sqrt_ev = exec_q.submit([&](sycl::handler &cgh) {
89+
cgh.depends_on(depends);
90+
constexpr size_t lws = 64;
91+
constexpr unsigned int vec_sz = 4;
92+
constexpr unsigned int n_vecs = 2;
93+
static_assert(lws % vec_sz == 0);
94+
auto gws_range = sycl::range<1>(
95+
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
96+
lws);
97+
auto lws_range = sycl::range<1>(lws);
98+
99+
using resTy = typename SqrtOutputType<argTy>::value_type;
100+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
101+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
102+
103+
cgh.parallel_for<
104+
class sqrt_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
105+
sycl::nd_range<1>(gws_range, lws_range),
106+
SqrtContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
107+
nelems));
108+
});
109+
return sqrt_ev;
110+
}
111+
112+
template <typename fnT, typename T> struct SqrtContigFactory
113+
{
114+
fnT get()
115+
{
116+
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
117+
void>) {
118+
fnT fn = nullptr;
119+
return fn;
120+
}
121+
else {
122+
fnT fn = sqrt_contig_impl<T>;
123+
return fn;
124+
}
125+
}
126+
};
127+
128+
template <typename fnT, typename T> struct SqrtTypeMapFactory
129+
{
130+
/*! @brief get typeid for output type of std::sqrt(T x) */
131+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
132+
{
133+
using rT = typename SqrtOutputType<T>::value_type;
134+
;
135+
return td_ns::GetTypeid<rT>{}.get();
136+
}
137+
};
138+
139+
template <typename T1, typename T2, typename T3> class sqrt_strided_kernel;
140+
141+
typedef sycl::event (*sqrt_strided_impl_fn_ptr_t)(
142+
sycl::queue,
143+
size_t,
144+
int,
145+
const py::ssize_t *,
146+
const char *,
147+
py::ssize_t,
148+
char *,
149+
py::ssize_t,
150+
const std::vector<sycl::event> &,
151+
const std::vector<sycl::event> &);
152+
153+
template <typename argTy>
154+
sycl::event
155+
sqrt_strided_impl(sycl::queue exec_q,
156+
size_t nelems,
157+
int nd,
158+
const py::ssize_t *shape_and_strides,
159+
const char *arg_p,
160+
py::ssize_t arg_offset,
161+
char *res_p,
162+
py::ssize_t res_offset,
163+
const std::vector<sycl::event> &depends,
164+
const std::vector<sycl::event> &additional_depends)
165+
{
166+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
167+
cgh.depends_on(depends);
168+
cgh.depends_on(additional_depends);
169+
170+
using resTy = typename SqrtOutputType<argTy>::value_type;
171+
using IndexerT =
172+
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
173+
174+
IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
175+
176+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
177+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
178+
179+
sycl::range<1> gRange{nelems};
180+
181+
cgh.parallel_for<sqrt_strided_kernel<argTy, resTy, IndexerT>>(
182+
gRange, SqrtStridedFunctor<argTy, resTy, IndexerT>(
183+
arg_tp, res_tp, arg_res_indexer));
184+
});
185+
return comp_ev;
186+
}
187+
188+
template <typename fnT, typename T> struct SqrtStridedFactory
189+
{
190+
fnT get()
191+
{
192+
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
193+
void>) {
194+
fnT fn = nullptr;
195+
return fn;
196+
}
197+
else {
198+
fnT fn = sqrt_strided_impl<T>;
199+
return fn;
200+
}
201+
}
202+
};
203+
204+
} // namespace sqrt
205+
} // namespace kernels
206+
} // namespace tensor
207+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "kernels/elementwise_functions/isfinite.hpp"
3939
#include "kernels/elementwise_functions/isinf.hpp"
4040
#include "kernels/elementwise_functions/isnan.hpp"
41+
#include "kernels/elementwise_functions/sqrt.hpp"
4142

4243
namespace dpctl
4344
{
@@ -342,6 +343,43 @@ void populate_add_dispatch_tables(void)
342343

343344
} // namespace impl
344345

346+
// SQRT
347+
namespace impl
348+
{
349+
350+
namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt;
351+
using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t;
352+
using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t;
353+
354+
static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types];
355+
static int sqrt_output_typeid_vector[td_ns::num_types];
356+
static sqrt_strided_impl_fn_ptr_t
357+
sqrt_strided_dispatch_vector[td_ns::num_types];
358+
359+
void populate_sqrt_dispatch_vectors(void)
360+
{
361+
using namespace td_ns;
362+
namespace fn_ns = sqrt_fn_ns;
363+
364+
using fn_ns::SqrtContigFactory;
365+
DispatchVectorBuilder<sqrt_contig_impl_fn_ptr_t, SqrtContigFactory,
366+
num_types>
367+
dvb1;
368+
dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector);
369+
370+
using fn_ns::SqrtStridedFactory;
371+
DispatchVectorBuilder<sqrt_strided_impl_fn_ptr_t, SqrtStridedFactory,
372+
num_types>
373+
dvb2;
374+
dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector);
375+
376+
using fn_ns::SqrtTypeMapFactory;
377+
DispatchVectorBuilder<int, SqrtTypeMapFactory, num_types> dvb3;
378+
dvb3.populate_dispatch_vector(sqrt_output_typeid_vector);
379+
}
380+
381+
} // namespace impl
382+
345383
namespace py = pybind11;
346384

347385
void init_elementwise_functions(py::module_ m)
@@ -649,7 +687,26 @@ void init_elementwise_functions(py::module_ m)
649687
// FIXME:
650688

651689
// U33: ==== SQRT (x)
652-
// FIXME:
690+
{
691+
impl::populate_sqrt_dispatch_vectors();
692+
using impl::sqrt_contig_dispatch_vector;
693+
using impl::sqrt_output_typeid_vector;
694+
using impl::sqrt_strided_dispatch_vector;
695+
696+
auto sqrt_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q,
697+
const event_vecT &depends = {}) {
698+
return py_unary_ufunc(
699+
src, dst, exec_q, depends, sqrt_output_typeid_vector,
700+
sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector);
701+
};
702+
m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"),
703+
py::arg("sycl_queue"), py::arg("depends") = py::list());
704+
705+
auto sqrt_result_type_pyapi = [&](py::dtype dtype) {
706+
return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector);
707+
};
708+
m.def("_sqrt_result_type", sqrt_result_type_pyapi);
709+
}
653710

654711
// B23: ==== SUBTRACT (x1, x2)
655712
// FIXME:

0 commit comments

Comments
 (0)