1
1
#pragma once
2
2
#include < CL/sycl.hpp>
3
3
4
+ #include " kernels/elementwise_functions/common.hpp"
5
+
4
6
#include " utils/offset_utils.hpp"
5
7
#include " utils/type_dispatch.hpp"
6
8
#include " utils/type_utils.hpp"
7
9
#include < pybind11/pybind11.h>
8
10
11
+ #include < iostream>
12
+
9
13
namespace dpctl
10
14
{
11
15
namespace tensor
@@ -18,120 +22,40 @@ namespace abs
18
22
namespace py = pybind11;
19
23
namespace td_ns = dpctl::tensor::type_dispatch;
20
24
21
- template <typename argT,
22
- typename resT = argT,
23
- unsigned int vec_sz = 4 ,
24
- unsigned int n_vecs = 2 >
25
- struct AbsContigFunctor
25
+ using dpctl::tensor::type_utils::is_complex;
26
+
27
+ template <typename argT, typename resT> struct AbsFunctor
26
28
{
27
- private:
28
- const argT *in = nullptr ;
29
- resT *out = nullptr ;
30
- const size_t nelems_;
31
29
32
- public:
33
- AbsContigFunctor ( const argT *inp, resT *res, const size_t n_elems)
34
- : in(inp), out(res), nelems_(n_elems)
35
- {
36
- }
30
+ using is_constant = typename std::false_type;
31
+ // constexpr resT constant_value = resT{};
32
+ using supports_vec = typename std::false_type;
33
+ using supports_sg_loadstore = typename std::negation<
34
+ std::disjunction<is_complex<resT>, is_complex<argT>>>;
37
35
38
- void operator ()(sycl::nd_item< 1 > ndit) const
36
+ resT operator ()(const argT &x)
39
37
{
40
- /* Each work-item processes vec_sz elements, contiguous in memory */
41
- /* NOTE: vec_sz must divide sg.max_local_range()[0] */
42
38
43
39
if constexpr (std::is_same_v<argT, bool > ||
44
40
(std::is_integral<argT>::value &&
45
41
std::is_unsigned<argT>::value))
46
42
{
47
43
static_assert (std::is_same_v<resT, argT>);
48
-
49
- auto sg = ndit.get_sub_group ();
50
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
51
- std::uint8_t max_sgSize = sg.get_max_local_range ()[0 ];
52
- size_t base = n_vecs * vec_sz *
53
- (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
54
- sg.get_group_id ()[0 ] * max_sgSize);
55
-
56
- if (base + n_vecs * vec_sz * sgSize < nelems_ &&
57
- sgSize == max_sgSize) {
58
- using in_ptrT =
59
- sycl::multi_ptr<const argT,
60
- sycl::access::address_space::global_space>;
61
- using out_ptrT =
62
- sycl::multi_ptr<resT,
63
- sycl::access::address_space::global_space>;
64
- sycl::vec<argT, vec_sz> arg_vec;
65
-
66
- #pragma unroll
67
- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
68
- arg_vec = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
69
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
70
- arg_vec);
71
- }
72
- }
73
- else {
74
- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
75
- k += sgSize) {
76
- out[k] = in[k];
77
- }
78
- }
44
+ return x;
79
45
}
80
46
else {
81
- using dpctl::tensor::type_utils::is_complex;
82
- if constexpr (is_complex<argT>::value) {
83
- std::uint8_t sgSize = ndit.get_sub_group ().get_local_range ()[0 ];
84
- size_t base = ndit.get_global_linear_id ();
85
-
86
- base = (base / sgSize) * sgSize * n_vecs * vec_sz +
87
- (base % sgSize);
88
- for (size_t offset = base;
89
- offset <
90
- std::min (nelems_, base + sgSize * (n_vecs * vec_sz));
91
- offset += sgSize)
92
- {
93
- out[offset] = std::abs (in[offset]);
94
- }
95
- }
96
- else {
97
- auto sg = ndit.get_sub_group ();
98
- std::uint8_t sgSize = sg.get_local_range ()[0 ];
99
- std::uint8_t maxsgSize = sg.get_max_local_range ()[0 ];
100
- size_t base = n_vecs * vec_sz *
101
- (ndit.get_group (0 ) * ndit.get_local_range (0 ) +
102
- sg.get_group_id ()[0 ] * maxsgSize);
103
-
104
- if (base + n_vecs * vec_sz < nelems_) {
105
- using in_ptrT = sycl::multi_ptr<
106
- const argT, sycl::access::address_space::global_space>;
107
- using out_ptrT = sycl::multi_ptr<
108
- resT, sycl::access::address_space::global_space>;
109
- sycl::vec<argT, vec_sz> arg_vec;
110
-
111
- #pragma unroll
112
- for (std::uint8_t it = 0 ; it < n_vecs * vec_sz;
113
- it += vec_sz) {
114
- arg_vec =
115
- sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
116
- #pragma unroll
117
- for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
118
- arg_vec[k] = std::abs (arg_vec[k]);
119
- }
120
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
121
- arg_vec);
122
- }
123
- }
124
- else {
125
- for (size_t k = base + sg.get_local_id ()[0 ]; k < nelems_;
126
- k += sgSize) {
127
- out[k] = std::abs (in[k]);
128
- }
129
- }
130
- }
47
+ return std::abs (x);
131
48
}
132
49
}
133
50
};
134
51
52
+ template <typename argT,
53
+ typename resT = argT,
54
+ unsigned int vec_sz = 4 ,
55
+ unsigned int n_vecs = 2 >
56
+ using AbsContigFunctor = elementwise_common::
57
+ UnaryContigFunctor<argT, resT, AbsFunctor<argT, resT>, vec_sz, n_vecs>;
58
+
135
59
template <typename T> struct AbsOutputType
136
60
{
137
61
using value_type = typename std::disjunction< // disjunction is C++17
@@ -220,39 +144,9 @@ template <typename fnT, typename T> struct AbsTypeMapFactory
220
144
}
221
145
};
222
146
223
- template <typename argT, typename resT, typename IndexerT>
224
- struct AbsStridedFunctor
225
- {
226
- private:
227
- const argT *in = nullptr ;
228
- resT *out = nullptr ;
229
- IndexerT inp_res_indexer_;
230
-
231
- public:
232
- AbsStridedFunctor (const argT *inp_p,
233
- resT *res_p,
234
- IndexerT two_offsets_indexer)
235
- : in(inp_p), out(res_p), inp_res_indexer_(two_offsets_indexer)
236
- {
237
- }
238
-
239
- void operator ()(sycl::id<1 > wid) const
240
- {
241
- auto offsets_ = inp_res_indexer_ (static_cast <py::ssize_t >(wid[0 ]));
242
- const auto &inp_offset = offsets_.get_first_offset ();
243
- const auto &out_offset = offsets_.get_second_offset ();
244
-
245
- if constexpr (std::is_same_v<argT, bool > ||
246
- (std::is_integral<argT>::value &&
247
- std::is_unsigned<argT>::value))
248
- {
249
- out[out_offset] = in[inp_offset];
250
- }
251
- else {
252
- out[out_offset] = std::abs (in[inp_offset]);
253
- }
254
- }
255
- };
147
+ template <typename argTy, typename resTy, typename IndexerT>
148
+ using AbsStridedFunctor = elementwise_common::
149
+ UnaryStridedFunctor<argTy, resTy, IndexerT, AbsFunctor<argTy, resTy>>;
256
150
257
151
template <typename T1, typename T2, typename T3> class abs_strided_kernel ;
258
152
0 commit comments