Skip to content

Commit 4952132

Browse files
committed
Implements dpctl.tensor.sqrt
1 parent b49744a commit 4952132

File tree

4 files changed

+277
-2
lines changed

4 files changed

+277
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
from dpctl.tensor._usmarray import usm_ndarray
9191

9292
from ._constants import e, inf, nan, newaxis, pi
93-
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan
93+
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt
9494

9595
__all__ = [
9696
"Device",
@@ -171,4 +171,5 @@
171171
"isinf",
172172
"isnan",
173173
"isfinite",
174+
"sqrt",
174175
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,13 @@
7171
# Element-wise `isinf`, assembled from the `ti._isinf_result_type` and
# `ti._isinf` callables together with its docstring text.
isinf = UnaryElementwiseFunc(
    "isinf",
    ti._isinf_result_type,
    ti._isinf,
    _isinf_docstring_,
)
74+
75+
# SQRT

# Docstring text handed to UnaryElementwiseFunc for `sqrt`.
_sqrt_docstring_ = """
Computes sqrt for each element `x_i` for input array `x`.
"""

# Element-wise `sqrt`, assembled from the `ti._sqrt_result_type` and
# `ti._sqrt` callables together with the docstring above.
sqrt = UnaryElementwiseFunc(
    "sqrt",
    ti._sqrt_result_type,
    ti._sqrt,
    _sqrt_docstring_,
)
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#pragma once
2+
#include <CL/sycl.hpp>
3+
#include <cmath>
4+
#include <cstddef>
5+
#include <cstdint>
6+
#include <type_traits>
7+
8+
#include "kernels/elementwise_functions/common.hpp"
9+
10+
#include "utils/offset_utils.hpp"
11+
#include "utils/type_dispatch.hpp"
12+
#include "utils/type_utils.hpp"
13+
#include <pybind11/pybind11.h>
14+
15+
namespace dpctl
16+
{
17+
namespace tensor
18+
{
19+
namespace kernels
20+
{
21+
namespace sqrt
22+
{
23+
24+
namespace py = pybind11;
25+
namespace td_ns = dpctl::tensor::type_dispatch;
26+
27+
using dpctl::tensor::type_utils::is_complex;
28+
29+
template <typename argT, typename resT> struct SqrtFunctor
30+
{
31+
32+
// is function constant for given argT
33+
using is_constant = typename std::false_type;
34+
// constant value, if constant
35+
// constexpr resT constant_value = resT{};
36+
// is function defined for sycl::vec
37+
using supports_vec = typename std::false_type;
38+
// do both argTy and resTy support sugroup store/load operation
39+
using supports_sg_loadstore = typename std::negation<
40+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
41+
42+
resT operator()(const argT &in)
43+
{
44+
return std::sqrt(in);
45+
}
46+
};
47+
48+
template <typename argTy,
49+
typename resTy = argTy,
50+
unsigned int vec_sz = 4,
51+
unsigned int n_vecs = 2>
52+
using SqrtContigFunctor = elementwise_common::
53+
UnaryContigFunctor<argTy, resTy, SqrtFunctor<argTy, resTy>, vec_sz, n_vecs>;
54+
55+
template <typename argTy, typename resTy, typename IndexerT>
56+
using SqrtStridedFunctor = elementwise_common::
57+
UnaryStridedFunctor<argTy, resTy, IndexerT, SqrtFunctor<argTy, resTy>>;
58+
59+
template <typename T> struct SqrtOutputType
60+
{
61+
using value_type = typename std::disjunction< // disjunction is C++17
62+
// feature, supported by DPC++
63+
td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
64+
td_ns::TypeMapEntry<T, float, float>,
65+
td_ns::TypeMapEntry<T, double, double>,
66+
td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
67+
td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
68+
td_ns::DefaultEntry<void>>::result_type;
69+
};
70+
71+
/*! @brief Pointer to a contiguous sqrt implementation.
 *
 *  Arguments, in order (names as in sqrt_contig_impl): execution queue,
 *  number of elements, input bytes, output bytes, events to wait for.
 */
using sqrt_contig_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,
                    size_t,
                    const char *,
                    char *,
                    const std::vector<sycl::event> &);
77+
78+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
79+
class sqrt_contig_kernel;
80+
81+
template <typename argTy>
82+
sycl::event sqrt_contig_impl(sycl::queue exec_q,
83+
size_t nelems,
84+
const char *arg_p,
85+
char *res_p,
86+
const std::vector<sycl::event> &depends = {})
87+
{
88+
sycl::event sqrt_ev = exec_q.submit([&](sycl::handler &cgh) {
89+
cgh.depends_on(depends);
90+
constexpr size_t lws = 64;
91+
constexpr unsigned int vec_sz = 4;
92+
constexpr unsigned int n_vecs = 2;
93+
static_assert(lws % vec_sz == 0);
94+
auto gws_range = sycl::range<1>(
95+
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
96+
lws);
97+
auto lws_range = sycl::range<1>(lws);
98+
99+
using resTy = typename SqrtOutputType<argTy>::value_type;
100+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
101+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
102+
103+
cgh.parallel_for<
104+
class sqrt_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
105+
sycl::nd_range<1>(gws_range, lws_range),
106+
SqrtContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
107+
nelems));
108+
});
109+
return sqrt_ev;
110+
}
111+
112+
template <typename fnT, typename T> struct SqrtContigFactory
113+
{
114+
fnT get()
115+
{
116+
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
117+
void>) {
118+
fnT fn = nullptr;
119+
return fn;
120+
}
121+
else {
122+
fnT fn = sqrt_contig_impl<T>;
123+
return fn;
124+
}
125+
}
126+
};
127+
128+
template <typename fnT, typename T> struct SqrtTypeMapFactory
129+
{
130+
/*! @brief get typeid for output type of std::sqrt(T x) */
131+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
132+
{
133+
using rT = typename SqrtOutputType<T>::value_type;
134+
;
135+
return td_ns::GetTypeid<rT>{}.get();
136+
}
137+
};
138+
139+
/*! @brief Kernel name type for the strided sqrt kernel. */
template <typename T1, typename T2, typename T3>
class sqrt_strided_kernel;
140+
141+
/*! @brief Pointer to a strided sqrt implementation.
 *
 *  Arguments, in order (names as in sqrt_strided_impl): execution queue,
 *  number of elements, number of dimensions, shape-and-strides array,
 *  input bytes, input offset, output bytes, output offset, and two lists
 *  of events to wait for.
 */
using sqrt_strided_impl_fn_ptr_t =
    sycl::event (*)(sycl::queue,
                    size_t,
                    int,
                    const py::ssize_t *,
                    const char *,
                    py::ssize_t,
                    char *,
                    py::ssize_t,
                    const std::vector<sycl::event> &,
                    const std::vector<sycl::event> &);
152+
153+
template <typename argTy>
154+
sycl::event
155+
sqrt_strided_impl(sycl::queue exec_q,
156+
size_t nelems,
157+
int nd,
158+
const py::ssize_t *shape_and_strides,
159+
const char *arg_p,
160+
py::ssize_t arg_offset,
161+
char *res_p,
162+
py::ssize_t res_offset,
163+
const std::vector<sycl::event> &depends,
164+
const std::vector<sycl::event> &additional_depends)
165+
{
166+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
167+
cgh.depends_on(depends);
168+
cgh.depends_on(additional_depends);
169+
170+
using resTy = typename SqrtOutputType<argTy>::value_type;
171+
using IndexerT =
172+
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
173+
174+
IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);
175+
176+
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
177+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
178+
179+
sycl::range<1> gRange{nelems};
180+
181+
cgh.parallel_for<sqrt_strided_kernel<argTy, resTy, IndexerT>>(
182+
gRange, SqrtStridedFunctor<argTy, resTy, IndexerT>(
183+
arg_tp, res_tp, arg_res_indexer));
184+
});
185+
return comp_ev;
186+
}
187+
188+
template <typename fnT, typename T> struct SqrtStridedFactory
189+
{
190+
fnT get()
191+
{
192+
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
193+
void>) {
194+
fnT fn = nullptr;
195+
return fn;
196+
}
197+
else {
198+
fnT fn = sqrt_strided_impl<T>;
199+
return fn;
200+
}
201+
}
202+
};
203+
204+
} // namespace sqrt
205+
} // namespace kernels
206+
} // namespace tensor
207+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "kernels/elementwise_functions/isfinite.hpp"
3939
#include "kernels/elementwise_functions/isinf.hpp"
4040
#include "kernels/elementwise_functions/isnan.hpp"
41+
#include "kernels/elementwise_functions/sqrt.hpp"
4142

4243
namespace dpctl
4344
{
@@ -325,6 +326,43 @@ void populate_add_dispatch_tables(void)
325326

326327
} // namespace impl
327328

329+
// SQRT
330+
namespace impl
331+
{
332+
333+
namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt;
334+
using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t;
335+
using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t;
336+
337+
static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types];
338+
static int sqrt_output_typeid_vector[td_ns::num_types];
339+
static sqrt_strided_impl_fn_ptr_t
340+
sqrt_strided_dispatch_vector[td_ns::num_types];
341+
342+
void populate_sqrt_dispatch_vectors(void)
343+
{
344+
using namespace td_ns;
345+
namespace fn_ns = sqrt_fn_ns;
346+
347+
using fn_ns::SqrtContigFactory;
348+
DispatchVectorBuilder<sqrt_contig_impl_fn_ptr_t, SqrtContigFactory,
349+
num_types>
350+
dvb1;
351+
dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector);
352+
353+
using fn_ns::SqrtStridedFactory;
354+
DispatchVectorBuilder<sqrt_strided_impl_fn_ptr_t, SqrtStridedFactory,
355+
num_types>
356+
dvb2;
357+
dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector);
358+
359+
using fn_ns::SqrtTypeMapFactory;
360+
DispatchVectorBuilder<int, SqrtTypeMapFactory, num_types> dvb3;
361+
dvb3.populate_dispatch_vector(sqrt_output_typeid_vector);
362+
}
363+
364+
} // namespace impl
365+
328366
namespace py = pybind11;
329367

330368
void init_elementwise_functions(py::module_ m)
@@ -628,7 +666,26 @@ void init_elementwise_functions(py::module_ m)
628666
// FIXME:
629667

630668
// U33: ==== SQRT (x)
631-
// FIXME:
669+
{
    using impl::sqrt_contig_dispatch_vector;
    using impl::sqrt_output_typeid_vector;
    using impl::sqrt_strided_dispatch_vector;

    // fill the dispatch vectors that the bindings below read from
    impl::populate_sqrt_dispatch_vectors();

    // `_sqrt`: forwards to py_unary_ufunc with the sqrt dispatch vectors
    auto sqrt_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q,
                          const event_vecT &depends = {}) {
        return py_unary_ufunc(src, dst, exec_q, depends,
                              sqrt_output_typeid_vector,
                              sqrt_contig_dispatch_vector,
                              sqrt_strided_dispatch_vector);
    };
    m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());

    // `_sqrt_result_type`: result dtype lookup for a given input dtype
    auto sqrt_result_type_pyapi = [&](py::dtype dtype) {
        return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector);
    };
    m.def("_sqrt_result_type", sqrt_result_type_pyapi);
}
632689

633690
// B23: ==== SUBTRACT (x1, x2)
634691
// FIXME:

0 commit comments

Comments
 (0)