Skip to content

Implements dpctl.tensor.sqrt #1205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
from dpctl.tensor._usmarray import usm_ndarray

from ._constants import e, inf, nan, newaxis, pi
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan, sqrt

__all__ = [
"Device",
Expand Down Expand Up @@ -171,4 +171,5 @@
"isinf",
"isnan",
"isfinite",
"sqrt",
]
10 changes: 10 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,13 @@
isinf = UnaryElementwiseFunc(
"isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
)

# SQRT

_sqrt_docstring_ = """
Computes sqrt for each element `x_i` for input array `x`.
"""

sqrt = UnaryElementwiseFunc(
"sqrt", ti._sqrt_result_type, ti._sqrt, _sqrt_docstring_
)
207 changes: 207 additions & 0 deletions dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
#pragma once
#include <CL/sycl.hpp>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"
#include <pybind11/pybind11.h>

namespace dpctl
{
namespace tensor
{
namespace kernels
{
namespace sqrt
{

namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;

template <typename argT, typename resT> struct SqrtFunctor
{

// is function constant for given argT
using is_constant = typename std::false_type;
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
// do both argTy and resTy support sugroup store/load operation
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;

resT operator()(const argT &in)
{
return std::sqrt(in);
}
};

template <typename argTy,
typename resTy = argTy,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2>
using SqrtContigFunctor = elementwise_common::
UnaryContigFunctor<argTy, resTy, SqrtFunctor<argTy, resTy>, vec_sz, n_vecs>;

template <typename argTy, typename resTy, typename IndexerT>
using SqrtStridedFunctor = elementwise_common::
UnaryStridedFunctor<argTy, resTy, IndexerT, SqrtFunctor<argTy, resTy>>;

template <typename T> struct SqrtOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
td_ns::TypeMapEntry<T, sycl::half, sycl::half>,
td_ns::TypeMapEntry<T, float, float>,
td_ns::TypeMapEntry<T, double, double>,
td_ns::TypeMapEntry<T, std::complex<float>, std::complex<float>>,
td_ns::TypeMapEntry<T, std::complex<double>, std::complex<double>>,
td_ns::DefaultEntry<void>>::result_type;
};

typedef sycl::event (*sqrt_contig_impl_fn_ptr_t)(
sycl::queue,
size_t,
const char *,
char *,
const std::vector<sycl::event> &);

template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
class sqrt_contig_kernel;

template <typename argTy>
sycl::event sqrt_contig_impl(sycl::queue exec_q,
size_t nelems,
const char *arg_p,
char *res_p,
const std::vector<sycl::event> &depends = {})
{
sycl::event sqrt_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
constexpr size_t lws = 64;
constexpr unsigned int vec_sz = 4;
constexpr unsigned int n_vecs = 2;
static_assert(lws % vec_sz == 0);
auto gws_range = sycl::range<1>(
((nelems + n_vecs * lws * vec_sz - 1) / (lws * n_vecs * vec_sz)) *
lws);
auto lws_range = sycl::range<1>(lws);

using resTy = typename SqrtOutputType<argTy>::value_type;
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

cgh.parallel_for<
class sqrt_contig_kernel<argTy, resTy, vec_sz, n_vecs>>(
sycl::nd_range<1>(gws_range, lws_range),
SqrtContigFunctor<argTy, resTy, vec_sz, n_vecs>(arg_tp, res_tp,
nelems));
});
return sqrt_ev;
}

template <typename fnT, typename T> struct SqrtContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
void>) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = sqrt_contig_impl<T>;
return fn;
}
}
};

template <typename fnT, typename T> struct SqrtTypeMapFactory
{
/*! @brief get typeid for output type of std::sqrt(T x) */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
using rT = typename SqrtOutputType<T>::value_type;
;
return td_ns::GetTypeid<rT>{}.get();
}
};

template <typename T1, typename T2, typename T3> class sqrt_strided_kernel;

typedef sycl::event (*sqrt_strided_impl_fn_ptr_t)(
sycl::queue,
size_t,
int,
const py::ssize_t *,
const char *,
py::ssize_t,
char *,
py::ssize_t,
const std::vector<sycl::event> &,
const std::vector<sycl::event> &);

template <typename argTy>
sycl::event
sqrt_strided_impl(sycl::queue exec_q,
size_t nelems,
int nd,
const py::ssize_t *shape_and_strides,
const char *arg_p,
py::ssize_t arg_offset,
char *res_p,
py::ssize_t res_offset,
const std::vector<sycl::event> &depends,
const std::vector<sycl::event> &additional_depends)
{
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.depends_on(additional_depends);

using resTy = typename SqrtOutputType<argTy>::value_type;
using IndexerT =
typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;

IndexerT arg_res_indexer(nd, arg_offset, res_offset, shape_and_strides);

const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_p);
resTy *res_tp = reinterpret_cast<resTy *>(res_p);

sycl::range<1> gRange{nelems};

cgh.parallel_for<sqrt_strided_kernel<argTy, resTy, IndexerT>>(
gRange, SqrtStridedFunctor<argTy, resTy, IndexerT>(
arg_tp, res_tp, arg_res_indexer));
});
return comp_ev;
}

template <typename fnT, typename T> struct SqrtStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename SqrtOutputType<T>::value_type,
void>) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = sqrt_strided_impl<T>;
return fn;
}
}
};

} // namespace sqrt
} // namespace kernels
} // namespace tensor
} // namespace dpctl
59 changes: 58 additions & 1 deletion dpctl/tensor/libtensor/source/elementwise_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "kernels/elementwise_functions/isfinite.hpp"
#include "kernels/elementwise_functions/isinf.hpp"
#include "kernels/elementwise_functions/isnan.hpp"
#include "kernels/elementwise_functions/sqrt.hpp"

namespace dpctl
{
Expand Down Expand Up @@ -325,6 +326,43 @@ void populate_add_dispatch_tables(void)

} // namespace impl

// SQRT
namespace impl
{

namespace sqrt_fn_ns = dpctl::tensor::kernels::sqrt;
using sqrt_fn_ns::sqrt_contig_impl_fn_ptr_t;
using sqrt_fn_ns::sqrt_strided_impl_fn_ptr_t;

static sqrt_contig_impl_fn_ptr_t sqrt_contig_dispatch_vector[td_ns::num_types];
static int sqrt_output_typeid_vector[td_ns::num_types];
static sqrt_strided_impl_fn_ptr_t
sqrt_strided_dispatch_vector[td_ns::num_types];

void populate_sqrt_dispatch_vectors(void)
{
using namespace td_ns;
namespace fn_ns = sqrt_fn_ns;

using fn_ns::SqrtContigFactory;
DispatchVectorBuilder<sqrt_contig_impl_fn_ptr_t, SqrtContigFactory,
num_types>
dvb1;
dvb1.populate_dispatch_vector(sqrt_contig_dispatch_vector);

using fn_ns::SqrtStridedFactory;
DispatchVectorBuilder<sqrt_strided_impl_fn_ptr_t, SqrtStridedFactory,
num_types>
dvb2;
dvb2.populate_dispatch_vector(sqrt_strided_dispatch_vector);

using fn_ns::SqrtTypeMapFactory;
DispatchVectorBuilder<int, SqrtTypeMapFactory, num_types> dvb3;
dvb3.populate_dispatch_vector(sqrt_output_typeid_vector);
}

} // namespace impl

namespace py = pybind11;

void init_elementwise_functions(py::module_ m)
Expand Down Expand Up @@ -628,7 +666,26 @@ void init_elementwise_functions(py::module_ m)
// FIXME:

// U33: ==== SQRT (x)
// FIXME:
{
impl::populate_sqrt_dispatch_vectors();
using impl::sqrt_contig_dispatch_vector;
using impl::sqrt_output_typeid_vector;
using impl::sqrt_strided_dispatch_vector;

auto sqrt_pyapi = [&](arrayT src, arrayT dst, sycl::queue exec_q,
const event_vecT &depends = {}) {
return py_unary_ufunc(
src, dst, exec_q, depends, sqrt_output_typeid_vector,
sqrt_contig_dispatch_vector, sqrt_strided_dispatch_vector);
};
m.def("_sqrt", sqrt_pyapi, "", py::arg("src"), py::arg("dst"),
py::arg("sycl_queue"), py::arg("depends") = py::list());

auto sqrt_result_type_pyapi = [&](py::dtype dtype) {
return py_unary_ufunc_result_type(dtype, sqrt_output_typeid_vector);
};
m.def("_sqrt_result_type", sqrt_result_type_pyapi);
}

// B23: ==== SUBTRACT (x1, x2)
// FIXME:
Expand Down
Loading