29
29
#include < sycl/sycl.hpp>
30
30
#include < type_traits>
31
31
32
- #include " sycl_complex.hpp"
32
+ #include " utils/ sycl_complex.hpp"
33
33
#include " vec_size_util.hpp"
34
34
35
35
#include " utils/offset_utils.hpp"
@@ -50,8 +50,10 @@ namespace add
50
50
{
51
51
52
52
using dpctl::tensor::ssize_t ;
53
+ namespace su_ns = dpctl::tensor::sycl_utils;
53
54
namespace td_ns = dpctl::tensor::type_dispatch;
54
55
namespace tu_ns = dpctl::tensor::type_utils;
56
+ namespace exprm_ns = sycl::ext::oneapi::experimental;
55
57
56
58
template <typename argT1, typename argT2, typename resT> struct AddFunctor
57
59
{
@@ -69,21 +71,22 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
69
71
using rT1 = typename argT1::value_type;
70
72
using rT2 = typename argT2::value_type;
71
73
72
- return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
74
+ return su_ns::sycl_complex_t <rT1>(in1) +
75
+ su_ns::sycl_complex_t <rT2>(in2);
73
76
}
74
77
else if constexpr (tu_ns::is_complex<argT1>::value &&
75
78
!tu_ns::is_complex<argT2>::value)
76
79
{
77
80
using rT1 = typename argT1::value_type;
78
81
79
- return exprm_ns::complex <rT1>(in1) + in2;
82
+ return su_ns:: sycl_complex_t <rT1>(in1) + in2;
80
83
}
81
84
else if constexpr (!tu_ns::is_complex<argT1>::value &&
82
85
tu_ns::is_complex<argT2>::value)
83
86
{
84
87
using rT2 = typename argT2::value_type;
85
88
86
- return in1 + exprm_ns::complex <rT2>(in2);
89
+ return in1 + su_ns:: sycl_complex_t <rT2>(in2);
87
90
}
88
91
else {
89
92
return in1 + in2;
@@ -460,7 +463,21 @@ template <typename argT, typename resT> struct AddInplaceFunctor
460
463
using supports_vec = std::negation<
461
464
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
462
465
463
- void operator ()(resT &res, const argT &in) { res += in; }
466
+ void operator ()(resT &res, const argT &in)
467
+ {
468
+ if constexpr (tu_ns::is_complex_v<resT> && tu_ns::is_complex_v<argT>) {
469
+ using rT1 = typename resT::value_type;
470
+ using rT2 = typename argT::value_type;
471
+
472
+ auto tmp = su_ns::sycl_complex_t <rT1>(res);
473
+ tmp += su_ns::sycl_complex_t <rT2>(in);
474
+
475
+ res = resT (tmp);
476
+ }
477
+ else {
478
+ res += in;
479
+ }
480
+ }
464
481
465
482
template <int vec_sz>
466
483
void operator ()(sycl::vec<resT, vec_sz> &res,
0 commit comments