Skip to content

Commit cb75d56

Browse files
committed
Refactor sycl_complex indirect include
* Move sycl_complex.hpp to utils * No longer use exprm_ns defined by header, define on per-file basis * Include alias to type sycl_complex_t<T> under sycl_utils namespace * Use identical include macro where inclusion of sycl_complex would be impossible
1 parent 8902414 commit cb75d56

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+235
-127
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acos.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace acos
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -72,7 +74,7 @@ template <typename argT, typename resT> struct AcosFunctor
7274
using realT = typename argT::value_type;
7375

7476
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75-
using sycl_complexT = exprm_ns::complex<realT>;
77+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
7678
sycl_complexT z = sycl_complexT(in);
7779
const realT x = exprm_ns::real(z);
7880
const realT y = exprm_ns::imag(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/acosh.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace acosh
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -77,7 +79,7 @@ template <typename argT, typename resT> struct AcoshFunctor
7779
* where the sign is chosen so Re(acosh(in)) >= 0.
7880
* So, we first calculate acos(in) and then acosh(in).
7981
*/
80-
using sycl_complexT = exprm_ns::complex<realT>;
82+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
8183
sycl_complexT z = sycl_complexT(in);
8284
const realT x = exprm_ns::real(z);
8385
const realT y = exprm_ns::imag(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

+22-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "utils/offset_utils.hpp"
@@ -50,8 +50,10 @@ namespace add
5050
{
5151

5252
using dpctl::tensor::ssize_t;
53+
namespace su_ns = dpctl::tensor::sycl_utils;
5354
namespace td_ns = dpctl::tensor::type_dispatch;
5455
namespace tu_ns = dpctl::tensor::type_utils;
56+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5557

5658
template <typename argT1, typename argT2, typename resT> struct AddFunctor
5759
{
@@ -69,21 +71,22 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
6971
using rT1 = typename argT1::value_type;
7072
using rT2 = typename argT2::value_type;
7173

72-
return exprm_ns::complex<rT1>(in1) + exprm_ns::complex<rT2>(in2);
74+
return su_ns::sycl_complex_t<rT1>(in1) +
75+
su_ns::sycl_complex_t<rT2>(in2);
7376
}
7477
else if constexpr (tu_ns::is_complex<argT1>::value &&
7578
!tu_ns::is_complex<argT2>::value)
7679
{
7780
using rT1 = typename argT1::value_type;
7881

79-
return exprm_ns::complex<rT1>(in1) + in2;
82+
return su_ns::sycl_complex_t<rT1>(in1) + in2;
8083
}
8184
else if constexpr (!tu_ns::is_complex<argT1>::value &&
8285
tu_ns::is_complex<argT2>::value)
8386
{
8487
using rT2 = typename argT2::value_type;
8588

86-
return in1 + exprm_ns::complex<rT2>(in2);
89+
return in1 + su_ns::sycl_complex_t<rT2>(in2);
8790
}
8891
else {
8992
return in1 + in2;
@@ -460,7 +463,21 @@ template <typename argT, typename resT> struct AddInplaceFunctor
460463
using supports_vec = std::negation<
461464
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
462465

463-
void operator()(resT &res, const argT &in) { res += in; }
466+
void operator()(resT &res, const argT &in)
467+
{
468+
if constexpr (tu_ns::is_complex_v<resT> && tu_ns::is_complex_v<argT>) {
469+
using rT1 = typename resT::value_type;
470+
using rT2 = typename argT::value_type;
471+
472+
auto tmp = su_ns::sycl_complex_t<rT1>(res);
473+
tmp += su_ns::sycl_complex_t<rT2>(in);
474+
475+
res = resT(tmp);
476+
}
477+
else {
478+
res += in;
479+
}
480+
}
464481

465482
template <int vec_sz>
466483
void operator()(sycl::vec<resT, vec_sz> &res,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

33-
#include "sycl_complex.hpp"
33+
#include "utils/sycl_complex.hpp"
3434
#include "vec_size_util.hpp"
3535

3636
#include "kernels/dpctl_tensor_types.hpp"
@@ -50,7 +50,9 @@ namespace angle
5050
{
5151

5252
using dpctl::tensor::ssize_t;
53+
namespace su_ns = dpctl::tensor::sycl_utils;
5354
namespace td_ns = dpctl::tensor::type_dispatch;
55+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5456

5557
using dpctl::tensor::type_utils::is_complex;
5658

@@ -71,7 +73,7 @@ template <typename argT, typename resT> struct AngleFunctor
7173
{
7274
using rT = typename argT::value_type;
7375

74-
return exprm_ns::arg(exprm_ns::complex<rT>(in)); // arg(in);
76+
return exprm_ns::arg(su_ns::sycl_complex_t<rT>(in)); // arg(in);
7577
}
7678
};
7779

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asin.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace asin
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -80,7 +82,7 @@ template <typename argT, typename resT> struct AsinFunctor
8082
* y = imag(I * conj(in)) = real(in)
8183
* and then return {imag(w), real(w)} which is asin(in)
8284
*/
83-
using sycl_complexT = exprm_ns::complex<realT>;
85+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
8486
sycl_complexT z = sycl_complexT(in);
8587
const realT x = exprm_ns::imag(z);
8688
const realT y = exprm_ns::real(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/asinh.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace asinh
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -72,7 +74,7 @@ template <typename argT, typename resT> struct AsinhFunctor
7274
using realT = typename argT::value_type;
7375

7476
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75-
using sycl_complexT = exprm_ns::complex<realT>;
77+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
7678
sycl_complexT z = sycl_complexT(in);
7779
const realT x = exprm_ns::real(z);
7880
const realT y = exprm_ns::imag(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atan.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

33-
#include "sycl_complex.hpp"
33+
#include "utils/sycl_complex.hpp"
3434
#include "vec_size_util.hpp"
3535

3636
#include "kernels/dpctl_tensor_types.hpp"
@@ -50,7 +50,9 @@ namespace atan
5050
{
5151

5252
using dpctl::tensor::ssize_t;
53+
namespace su_ns = dpctl::tensor::sycl_utils;
5354
namespace td_ns = dpctl::tensor::type_dispatch;
55+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5456

5557
using dpctl::tensor::kernels::vec_size_utils::ContigHyperparameterSetDefault;
5658
using dpctl::tensor::kernels::vec_size_utils::UnaryContigHyperparameterSetEntry;
@@ -83,7 +85,7 @@ template <typename argT, typename resT> struct AtanFunctor
8385
* y = imag(I * conj(in)) = real(in)
8486
* and then return {imag(w), real(w)} which is atan(in)
8587
*/
86-
using sycl_complexT = exprm_ns::complex<realT>;
88+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
8789
sycl_complexT z = sycl_complexT(in);
8890
const realT x = exprm_ns::imag(z);
8991
const realT y = exprm_ns::real(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/atanh.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include <sycl/sycl.hpp>
3131
#include <type_traits>
3232

33-
#include "sycl_complex.hpp"
33+
#include "utils/sycl_complex.hpp"
3434
#include "vec_size_util.hpp"
3535

3636
#include "kernels/dpctl_tensor_types.hpp"
@@ -50,7 +50,9 @@ namespace atanh
5050
{
5151

5252
using dpctl::tensor::ssize_t;
53+
namespace su_ns = dpctl::tensor::sycl_utils;
5354
namespace td_ns = dpctl::tensor::type_dispatch;
55+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5456

5557
using dpctl::tensor::type_utils::is_complex;
5658

@@ -73,7 +75,7 @@ template <typename argT, typename resT> struct AtanhFunctor
7375
using realT = typename argT::value_type;
7476
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
7577

76-
using sycl_complexT = exprm_ns::complex<realT>;
78+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
7779
sycl_complexT z = sycl_complexT(in);
7880
const realT x = exprm_ns::real(z);
7981
const realT y = exprm_ns::imag(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cabs_impl.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#include <complex>
2828
#include <limits>
2929

30-
#include "sycl_complex.hpp"
30+
#include "utils/sycl_complex.hpp"
3131

3232
namespace dpctl
3333
{
@@ -38,6 +38,9 @@ namespace kernels
3838
namespace detail
3939
{
4040

41+
namespace su_ns = dpctl::tensor::sycl_utils;
42+
namespace exprm_ns = sycl::ext::oneapi::experimental;
43+
4144
template <typename realT> realT cabs(std::complex<realT> const &z)
4245
{
4346
// Special values for cabs( x + y * 1j):
@@ -51,8 +54,8 @@ template <typename realT> realT cabs(std::complex<realT> const &z)
5154
// * If x is a finite number and y is NaN, the result is NaN.
5255
// * If x is NaN and y is NaN, the result is NaN.
5356

54-
using sycl_complexT = exprm_ns::complex<realT>;
55-
sycl_complexT _z = exprm_ns::complex<realT>(z);
57+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
58+
sycl_complexT _z = su_ns::sycl_complex_t<realT>(z);
5659
const realT x = exprm_ns::real(_z);
5760
const realT y = exprm_ns::imag(_z);
5861

dpctl/tensor/libtensor/include/kernels/elementwise_functions/conj.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#include <sycl/sycl.hpp>
3232
#include <type_traits>
3333

34-
#include "sycl_complex.hpp"
34+
#include "utils/sycl_complex.hpp"
3535
#include "vec_size_util.hpp"
3636

3737
#include "kernels/dpctl_tensor_types.hpp"
@@ -51,7 +51,9 @@ namespace conj
5151
{
5252

5353
using dpctl::tensor::ssize_t;
54+
namespace su_ns = dpctl::tensor::sycl_utils;
5455
namespace td_ns = dpctl::tensor::type_dispatch;
56+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5557

5658
using dpctl::tensor::type_utils::is_complex;
5759

@@ -73,7 +75,7 @@ template <typename argT, typename resT> struct ConjFunctor
7375
if constexpr (is_complex<argT>::value) {
7476
using rT = typename argT::value_type;
7577

76-
return exprm_ns::conj(exprm_ns::complex<rT>(in)); // conj(in);
78+
return exprm_ns::conj(su_ns::sycl_complex_t<rT>(in)); // conj(in);
7779
}
7880
else {
7981
if constexpr (!std::is_same_v<argT, bool>)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace cos
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -72,7 +74,7 @@ template <typename argT, typename resT> struct CosFunctor
7274
using realT = typename argT::value_type;
7375

7476
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
75-
using sycl_complexT = exprm_ns::complex<realT>;
77+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
7678
sycl_complexT z = sycl_complexT(in);
7779
const realT z_re = exprm_ns::real(z);
7880
const realT z_im = exprm_ns::imag(z);

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cosh.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
#include <sycl/sycl.hpp>
3030
#include <type_traits>
3131

32-
#include "sycl_complex.hpp"
32+
#include "utils/sycl_complex.hpp"
3333
#include "vec_size_util.hpp"
3434

3535
#include "kernels/dpctl_tensor_types.hpp"
@@ -49,7 +49,9 @@ namespace cosh
4949
{
5050

5151
using dpctl::tensor::ssize_t;
52+
namespace su_ns = dpctl::tensor::sycl_utils;
5253
namespace td_ns = dpctl::tensor::type_dispatch;
54+
namespace exprm_ns = sycl::ext::oneapi::experimental;
5355

5456
using dpctl::tensor::type_utils::is_complex;
5557

@@ -73,7 +75,7 @@ template <typename argT, typename resT> struct CoshFunctor
7375

7476
constexpr realT q_nan = std::numeric_limits<realT>::quiet_NaN();
7577

76-
using sycl_complexT = exprm_ns::complex<realT>;
78+
using sycl_complexT = su_ns::sycl_complex_t<realT>;
7779
sycl_complexT z = sycl_complexT(in);
7880
const realT x = exprm_ns::real(z);
7981
const realT y = exprm_ns::imag(z);

0 commit comments

Comments
 (0)