Skip to content

Commit 9a1d709

Browse files
committed
add complex types to oneapi::mkl::vm::abs
1 parent a3fa354 commit 9a1d709

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

dpnp/backend/extensions/vm/abs.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ namespace ext
3838
{
3939
namespace vm
4040
{
41+
template <typename T>
42+
struct value_type_of
43+
{
44+
using type = T;
45+
};
46+
47+
template <typename T>
48+
struct value_type_of<std::complex<T>>
49+
{
50+
using type = typename std::complex<T>::value_type;
51+
};
52+
4153
template <typename T>
4254
sycl::event abs_contig_impl(sycl::queue exec_q,
4355
const std::int64_t n,
@@ -48,7 +60,8 @@ sycl::event abs_contig_impl(sycl::queue exec_q,
4860
type_utils::validate_type_for_device<T>(exec_q);
4961

5062
const T *a = reinterpret_cast<const T *>(in_a);
51-
T *y = reinterpret_cast<T *>(out_y);
63+
using OutT = typename value_type_of<T>::type;
64+
OutT *y = reinterpret_cast<OutT *>(out_y);
5265

5366
return mkl_vm::abs(exec_q,
5467
n, // number of elements to be calculated

dpnp/backend/extensions/vm/types_matrix.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ template <typename T>
5353
struct AbsOutputType
5454
{
5555
using value_type = typename std::disjunction<
56-
// TODO: Add complex type here, currently adding them here generates a
57-
// compile time error due to probably a bug in mkl
56+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<double>>,
57+
dpctl_td_ns::TypeMapResultEntry<T, std::complex<float>>,
5858
dpctl_td_ns::TypeMapResultEntry<T, double>,
5959
dpctl_td_ns::TypeMapResultEntry<T, float>,
6060
dpctl_td_ns::DefaultResultEntry<void>>::result_type;

0 commit comments

Comments
 (0)