Skip to content

Allow different output type than input type when dispatching #1590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/abs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event abs_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AbsOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::abs(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/acos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event acos_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AcosOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::acos(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/acosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event acosh_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AcoshOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::acosh(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event add_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AddOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::add(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/asin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event asin_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AsinOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::asin(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/asinh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event asinh_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AsinhOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::asinh(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/atan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event atan_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AtanOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::atan(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/atan2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event atan2_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::Atan2OutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::atan2(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/atanh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event atanh_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::AtanhOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::atanh(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/ceil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event ceil_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::CeilOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::ceil(exec_q,
n, // number of elements to be calculated
Expand Down
34 changes: 7 additions & 27 deletions dpnp/backend/extensions/vm/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,8 @@ std::pair<sycl::event, sycl::event>
{
// check type_nums
int src_typenum = src.get_typenum();
int dst_typenum = dst.get_typenum();

auto array_types = dpctl_td_ns::usm_ndarray_types();
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

if (src_typeid != dst_typeid) {
throw py::value_error("Input and output arrays have different types.");
}

// check that queues are compatible
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
Expand Down Expand Up @@ -155,7 +148,7 @@ std::pair<sycl::event, sycl::event>
throw py::value_error("Input and outpur arrays must be C-contiguous.");
}

auto dispatch_fn = dispatch_vector[dst_typeid];
auto dispatch_fn = dispatch_vector[src_typeid];
if (dispatch_fn == nullptr) {
throw py::value_error("No implementation is defined for ufunc.");
}
Expand All @@ -179,16 +172,13 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
// check type_nums
int src1_typenum = src1.get_typenum();
int src2_typenum = src2.get_typenum();
int dst_typenum = dst.get_typenum();

auto array_types = dpctl_td_ns::usm_ndarray_types();
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
throw py::value_error(
"Either any of input arrays or output array have different types.");
if (src1_typeid != src2_typeid) {
throw py::value_error("Input arrays have different types.");
}

// check that queues are compatible
Expand Down Expand Up @@ -259,7 +249,7 @@ std::pair<sycl::event, sycl::event> binary_ufunc(
throw py::value_error("Input and outpur arrays must be C-contiguous.");
}

auto dispatch_fn = dispatch_vector[dst_typeid];
auto dispatch_fn = dispatch_vector[src1_typeid];
if (dispatch_fn == nullptr) {
throw py::value_error("No implementation is defined for ufunc.");
}
Expand All @@ -279,16 +269,8 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
{
// check type_nums
int src_typenum = src.get_typenum();
int dst_typenum = dst.get_typenum();

auto array_types = dpctl_td_ns::usm_ndarray_types();
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

// types must be the same
if (src_typeid != dst_typeid) {
return false;
}

// OneMKL VM functions perform a copy on host if no double type support
if (!exec_q.get_device().has(sycl::aspect::fp64)) {
Expand Down Expand Up @@ -356,7 +338,7 @@ bool need_to_call_unary_ufunc(sycl::queue exec_q,
}

// MKL function is not defined for the type
if (dispatch_vector[dst_typeid] == nullptr) {
if (dispatch_vector[src_typeid] == nullptr) {
return false;
}
return true;
Expand All @@ -372,15 +354,13 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
// check type_nums
int src1_typenum = src1.get_typenum();
int src2_typenum = src2.get_typenum();
int dst_typenum = dst.get_typenum();

auto array_types = dpctl_td_ns::usm_ndarray_types();
int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum);
int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum);
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

// types must be the same
if (src1_typeid != src2_typeid || src2_typeid != dst_typeid) {
if (src1_typeid != src2_typeid) {
return false;
}

Expand Down Expand Up @@ -454,7 +434,7 @@ bool need_to_call_binary_ufunc(sycl::queue exec_q,
}

// MKL function is not defined for the type
if (dispatch_vector[dst_typeid] == nullptr) {
if (dispatch_vector[src1_typeid] == nullptr) {
return false;
}
return true;
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/conj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event conj_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::ConjOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::conj(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/cos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event cos_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::CosOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::cos(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/cosh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event cosh_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::CoshOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::cosh(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/div.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event div_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::DivOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::div(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event exp_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::ExpOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::exp(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/expm1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event expm1_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::Expm1OutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::expm1(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/floor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event floor_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::FloorOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::floor(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/hypot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event hypot_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::HypotOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::hypot(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/ln.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event ln_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::LnOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::ln(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/log10.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event log10_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::Log10OutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::log10(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/log1p.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event log1p_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::Log1pOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::log1p(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/log2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event log2_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::Log2OutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::log2(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/mul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event mul_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::MulOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::mul(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ sycl::event pow_contig_impl(sycl::queue exec_q,

const T *a = reinterpret_cast<const T *>(in_a);
const T *b = reinterpret_cast<const T *>(in_b);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::PowOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::pow(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/round.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event round_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::RoundOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::rint(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/sin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event sin_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::SinOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::sin(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/sinh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event sinh_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::SinhOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::sinh(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/sqr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event sqr_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::SqrOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::sqr(exec_q,
n, // number of elements to be calculated
Expand Down
3 changes: 2 additions & 1 deletion dpnp/backend/extensions/vm/sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sycl::event sqrt_contig_impl(sycl::queue exec_q,
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);
using resTy = typename types::SqrtOutputType<T>::value_type;
resTy *y = reinterpret_cast<resTy *>(out_y);

return mkl_vm::sqrt(exec_q,
n, // number of elements to be calculated
Expand Down
Loading