Skip to content

Commit

Permalink
Improve dpnp.cos() and dpnp.sin() implementations (#1471)
Browse files Browse the repository at this point in the history
* Improve dpnp.cos() and dpnp.sin() implementations

* Update dpnp/backend/extensions/vm/vm_py.cpp

Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com>

* Update dpnp/backend/extensions/vm/vm_py.cpp

Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com>

---------

Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com>
  • Loading branch information
antonwolfy and vlad-perevezentsev authored Jul 7, 2023
1 parent f49da0e commit f82cdc4
Show file tree
Hide file tree
Showing 15 changed files with 580 additions and 105 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/math_tests/test_explog.py
third_party/cupy/math_tests/test_trigonometric.py
third_party/cupy/sorting_tests/test_sort.py
VER_JSON_NAME: 'version.json'
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "
Expand Down
78 changes: 78 additions & 0 deletions dpnp/backend/extensions/vm/cos.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>

#include "common.hpp"
#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace vm
{
template <typename T>
sycl::event cos_contig_impl(sycl::queue exec_q,
const std::int64_t n,
const char *in_a,
char *out_y,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);

return mkl_vm::cos(exec_q,
n, // number of elements to be calculated
a, // pointer `a` containing input vector of size n
y, // pointer `y` to the output vector of size n
depends);
}

template <typename fnT, typename T>
struct CosContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename types::CosOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return cos_contig_impl<T>;
}
}
};
} // namespace vm
} // namespace ext
} // namespace backend
} // namespace dpnp
78 changes: 78 additions & 0 deletions dpnp/backend/extensions/vm/sin.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>

#include "common.hpp"
#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace vm
{
template <typename T>
sycl::event sin_contig_impl(sycl::queue exec_q,
const std::int64_t n,
const char *in_a,
char *out_y,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);

return mkl_vm::sin(exec_q,
n, // number of elements to be calculated
a, // pointer `a` containing input vector of size n
y, // pointer `y` to the output vector of size n
depends);
}

template <typename fnT, typename T>
struct SinContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename types::SinOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return sin_contig_impl<T>;
}
}
};
} // namespace vm
} // namespace ext
} // namespace backend
} // namespace dpnp
38 changes: 38 additions & 0 deletions dpnp/backend/extensions/vm/types_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ struct DivOutputType
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct CosOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::ln<T> function.
Expand All @@ -86,6 +105,25 @@ struct LnOutputType
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct SinOutputType
{
using value_type = typename std::disjunction<
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
dpctl_td_ns::
TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
};
} // namespace types
} // namespace vm
} // namespace ext
Expand Down
60 changes: 60 additions & 0 deletions dpnp/backend/extensions/vm/vm_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
#include <pybind11/stl.h>

#include "common.hpp"
#include "cos.hpp"
#include "div.hpp"
#include "ln.hpp"
#include "sin.hpp"
#include "types_matrix.hpp"

namespace py = pybind11;
Expand All @@ -43,7 +45,9 @@ using vm_ext::unary_impl_fn_ptr_t;

static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];

static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];

PYBIND11_MODULE(_vm_impl, m)
{
Expand Down Expand Up @@ -80,6 +84,34 @@ PYBIND11_MODULE(_vm_impl, m)
py::arg("dst"));
}

// UnaryUfunc: ==== Cos(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
vm_ext::CosContigFactory>(
cos_dispatch_vector);

auto cos_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
const event_vecT &depends = {}) {
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
cos_dispatch_vector);
};
m.def("_cos", cos_pyapi,
"Call `cos` function from OneMKL VM library to compute "
"cosine of vector elements",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
py::arg("depends") = py::list());

auto cos_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
arrayT dst) {
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
cos_dispatch_vector);
};
m.def("_mkl_cos_to_call", cos_need_to_call_pyapi,
"Check input arguments to answer if `cos` function from "
"OneMKL VM library can be used",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
}

// UnaryUfunc: ==== Ln(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
Expand Down Expand Up @@ -107,4 +139,32 @@ PYBIND11_MODULE(_vm_impl, m)
"OneMKL VM library can be used",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
}

// UnaryUfunc: ==== Sin(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
vm_ext::SinContigFactory>(
sin_dispatch_vector);

auto sin_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
const event_vecT &depends = {}) {
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
sin_dispatch_vector);
};
m.def("_sin", sin_pyapi,
"Call `sin` function from OneMKL VM library to compute "
"sine of vector elements",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
py::arg("depends") = py::list());

auto sin_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
arrayT dst) {
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
sin_dispatch_vector);
};
m.def("_mkl_sin_to_call", sin_need_to_call_pyapi,
"Check input arguments to answer if `sin` function from "
"OneMKL VM library can be used",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
}
}
6 changes: 2 additions & 4 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_CORRELATE_EXT, /**< Used in numpy.correlate() impl, requires extra
parameters */
DPNP_FN_COS, /**< Used in numpy.cos() impl */
DPNP_FN_COS_EXT, /**< Used in numpy.cos() impl, requires extra parameters */
DPNP_FN_COSH, /**< Used in numpy.cosh() impl */
DPNP_FN_COSH, /**< Used in numpy.cosh() impl */
DPNP_FN_COSH_EXT, /**< Used in numpy.cosh() impl, requires extra parameters
*/
DPNP_FN_COUNT_NONZERO, /**< Used in numpy.count_nonzero() impl */
Expand Down Expand Up @@ -475,8 +474,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_SIGN_EXT, /**< Used in numpy.sign() impl, requires extra parameters
*/
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SIN_EXT, /**< Used in numpy.sin() impl, requires extra parameters */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SINH_EXT, /**< Used in numpy.sinh() impl, requires extra parameters
*/
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
Expand Down
18 changes: 0 additions & 18 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,6 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_COS][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_cos_c_default<double, double>};

fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_cos_c_ext<int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_cos_c_ext<int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_cos_c_ext<float, float>};
fmap[DPNPFuncName::DPNP_FN_COS_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_cos_c_ext<double, double>};

fmap[DPNPFuncName::DPNP_FN_COSH][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_cosh_c_default<int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_COSH][eft_LNG][eft_LNG] = {
Expand Down Expand Up @@ -711,15 +702,6 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_SIN][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_sin_c_default<double, double>};

fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_sin_c_ext<int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_LNG][eft_LNG] = {
eft_DBL, (void *)dpnp_sin_c_ext<int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_sin_c_ext<float, float>};
fmap[DPNPFuncName::DPNP_FN_SIN_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_sin_c_ext<double, double>};

fmap[DPNPFuncName::DPNP_FN_SINH][eft_INT][eft_INT] = {
eft_DBL, (void *)dpnp_sinh_c_default<int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_SINH][eft_LNG][eft_LNG] = {
Expand Down
6 changes: 0 additions & 6 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_COPYTO_EXT
DPNP_FN_CORRELATE
DPNP_FN_CORRELATE_EXT
DPNP_FN_COS
DPNP_FN_COS_EXT
DPNP_FN_COSH
DPNP_FN_COSH_EXT
DPNP_FN_COUNT_NONZERO
Expand Down Expand Up @@ -293,8 +291,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_SEARCHSORTED_EXT
DPNP_FN_SIGN
DPNP_FN_SIGN_EXT
DPNP_FN_SIN
DPNP_FN_SIN_EXT
DPNP_FN_SINH
DPNP_FN_SINH_EXT
DPNP_FN_SORT
Expand Down Expand Up @@ -546,7 +542,6 @@ cpdef dpnp_descriptor dpnp_arcsinh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_arctan(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_arctanh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_cbrt(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_cos(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_cosh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_degrees(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_exp(dpnp_descriptor array1, dpnp_descriptor out)
Expand All @@ -557,7 +552,6 @@ cpdef dpnp_descriptor dpnp_log1p(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sin(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1)
cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out)
cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1)
Expand Down
Loading

0 comments on commit f82cdc4

Please sign in to comment.