 #include "utils/offset_utils.hpp"
 #include "utils/type_dispatch.hpp"
 #include "utils/type_utils.hpp"
+
+#include "kernels/elementwise_functions/common.hpp"
 #include <pybind11/pybind11.h>
 
 namespace dpctl
@@ -20,101 +22,60 @@ namespace add
 
 namespace py = pybind11;
 namespace td_ns = dpctl::tensor::type_dispatch;
+namespace tu_ns = dpctl::tensor::type_utils;
 
-template <typename argT1,
-          typename argT2,
-          typename resT,
-          unsigned int vec_sz = 4,
-          unsigned int n_vecs = 2>
-struct AddContigFunctor
+template <typename argT1, typename argT2, typename resT> struct AddFunctor
 {
-private:
-    const argT1 *in1 = nullptr;
-    const argT2 *in2 = nullptr;
-    resT *out = nullptr;
-    const size_t nelems_;
-
-public:
-    AddContigFunctor(const argT1 *inp1,
-                     const argT2 *inp2,
-                     resT *res,
-                     const size_t n_elems)
-        : in1(inp1), in2(inp2), out(res), nelems_(n_elems)
+
+    using supports_sg_loadstore = std::negation<
+        std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
+    using supports_vec = std::negation<
+        std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
+
+    resT operator()(const argT1 &in1, const argT2 &in2)
     {
+        return in1 + in2;
     }
 
-    void operator()(sycl::nd_item<1> ndit) const
+    template <int vec_sz>
+    sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
+                                       const sycl::vec<argT2, vec_sz> &in2)
     {
-        /* Each work-item processes vec_sz elements, contiguous in memory */
-
-        using dpctl::tensor::type_utils::is_complex;
-        if constexpr (is_complex<argT1>::value || is_complex<argT2>::value) {
-            std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
-            size_t base = ndit.get_global_linear_id();
-
-            base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
-            for (size_t offset = base;
-                 offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
-                 offset += sgSize)
-            {
-                out[offset] = in1[offset] + in2[offset];
-            }
+        auto tmp = in1 + in2;
+        if constexpr (std::is_same_v<resT,
+                                     typename decltype(tmp)::element_type>) {
+            return tmp;
         }
         else {
-            auto sg = ndit.get_sub_group();
-            std::uint8_t sgSize = sg.get_local_range()[0];
-            std::uint8_t maxsgSize = sg.get_max_local_range()[0];
-            size_t base = n_vecs * vec_sz *
-                          (ndit.get_group(0) * ndit.get_local_range(0) +
-                           sg.get_group_id()[0] * maxsgSize);
+            using dpctl::tensor::type_utils::vec_cast;
 
-            if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
-                (sgSize == maxsgSize)) {
-                using in_ptrT1 =
-                    sycl::multi_ptr<const argT1,
-                                    sycl::access::address_space::global_space>;
-                using in_ptrT2 =
-                    sycl::multi_ptr<const argT2,
-                                    sycl::access::address_space::global_space>;
-                using out_ptrT =
-                    sycl::multi_ptr<resT,
-                                    sycl::access::address_space::global_space>;
-                sycl::vec<argT1, vec_sz> arg1_vec;
-                sycl::vec<argT2, vec_sz> arg2_vec;
-                sycl::vec<resT, vec_sz> res_vec;
-
-#pragma unroll
-                for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
-                    arg1_vec =
-                        sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
-                    arg2_vec =
-                        sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
-                    if constexpr (std::is_same_v<argT1, resT> &&
-                                  std::is_same_v<argT2, resT>) {
-                        res_vec = arg1_vec + arg2_vec;
-                    }
-                    else {
-                        using dpctl::tensor::type_utils::vec_cast;
-
-                        auto tmp = arg1_vec + arg2_vec;
-                        res_vec = std::move(
-                            vec_cast<resT, typename decltype(tmp)::element_type,
-                                     vec_sz>(tmp));
-                    }
-                    sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
-                                     res_vec);
-                }
-            }
-            else {
-                for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
-                     k += sgSize) {
-                    out[k] = in1[k] + in2[k];
-                }
-            }
+            return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
+                tmp);
         }
     }
 };
 
+template <typename argT1,
+          typename argT2,
+          typename resT,
+          unsigned int vec_sz = 4,
+          unsigned int n_vecs = 2>
+using AddContigFunctor =
+    elementwise_common::BinaryContigFunctor<argT1,
+                                            argT2,
+                                            resT,
+                                            AddFunctor<argT1, argT2, resT>,
+                                            vec_sz,
+                                            n_vecs>;
+
+template <typename argT1, typename argT2, typename resT, typename IndexerT>
+using AddStridedFunctor =
+    elementwise_common::BinaryStridedFunctor<argT1,
+                                             argT2,
+                                             resT,
+                                             IndexerT,
+                                             AddFunctor<argT1, argT2, resT>>;
+
 template <typename T1, typename T2> struct AddOutputType
 {
     using value_type = typename std::disjunction< // disjunction is C++17
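Note on the hunk above: the add-specific code is now just the two call operators of AddFunctor; the work-item loop, the sub-group vector loads and the tail handling live in the shared elementwise_common::BinaryContigFunctor pulled in from "kernels/elementwise_functions/common.hpp". The sketch below is not that implementation, only a simplified illustration of how such a shared functor can dispatch on the supports_sg_loadstore/supports_vec traits declared above (no n_vecs unrolling, no alignment or max-sub-group-size checks; the name SketchedBinaryContigFunctor is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <sycl/sycl.hpp>

    template <typename argT1,
              typename argT2,
              typename resT,
              typename BinaryOperatorT,
              unsigned int vec_sz = 4>
    struct SketchedBinaryContigFunctor
    {
        const argT1 *in1 = nullptr;
        const argT2 *in2 = nullptr;
        resT *out = nullptr;
        std::size_t nelems = 0;

        void operator()(sycl::nd_item<1> ndit) const
        {
            BinaryOperatorT op{};
            if constexpr (BinaryOperatorT::supports_sg_loadstore::value &&
                          BinaryOperatorT::supports_vec::value)
            {
                // Fast path: the sub-group loads vec_sz contiguous elements per
                // work-item and applies the sycl::vec overload of the operator.
                auto sg = ndit.get_sub_group();
                const std::uint8_t sgSize = sg.get_local_range()[0];
                const std::size_t base =
                    vec_sz * (ndit.get_group(0) * ndit.get_local_range(0) +
                              sg.get_group_id()[0] * sg.get_max_local_range()[0]);

                if (base + vec_sz * sgSize < nelems) {
                    using in_ptrT1 = sycl::multi_ptr<
                        const argT1, sycl::access::address_space::global_space>;
                    using in_ptrT2 = sycl::multi_ptr<
                        const argT2, sycl::access::address_space::global_space>;
                    using out_ptrT = sycl::multi_ptr<
                        resT, sycl::access::address_space::global_space>;

                    const sycl::vec<argT1, vec_sz> a =
                        sg.load<vec_sz>(in_ptrT1(&in1[base]));
                    const sycl::vec<argT2, vec_sz> b =
                        sg.load<vec_sz>(in_ptrT2(&in2[base]));
                    sg.store<vec_sz>(out_ptrT(&out[base]), op(a, b));
                }
                else {
                    // Tail of the array: fall back to the scalar overload.
                    for (std::size_t k = base + sg.get_local_id()[0]; k < nelems;
                         k += sgSize) {
                        out[k] = op(in1[k], in2[k]);
                    }
                }
            }
            else {
                // Operators that opt out (complex inputs here) get a plain
                // one-element-per-work-item pass through the scalar overload.
                const std::size_t gid = ndit.get_global_linear_id();
                if (gid < nelems) {
                    out[gid] = op(in1[gid], in2[gid]);
                }
            }
        }
    };

With AddFunctor plugged in as BinaryOperatorT, the scalar overload covers the complex and mixed-type cases and the sycl::vec overload covers the vectorized path, so per-operation headers no longer repeat the loop structure.
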
@@ -257,41 +218,6 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
     }
 };
 
-template <typename argT1,
-          typename argT2,
-          typename resT,
-          typename ThreeOffsets_IndexerT>
-struct AddStridedFunctor
-{
-private:
-    const argT1 *in1 = nullptr;
-    const argT2 *in2 = nullptr;
-    resT *out = nullptr;
-    ThreeOffsets_IndexerT three_offsets_indexer_;
-
-public:
-    AddStridedFunctor(const argT1 *inp1_tp,
-                      const argT2 *inp2_tp,
-                      resT *res_tp,
-                      ThreeOffsets_IndexerT inps_res_indexer)
-        : in1(inp1_tp), in2(inp2_tp), out(res_tp),
-          three_offsets_indexer_(inps_res_indexer)
-    {
-    }
-
-    void operator()(sycl::id<1> wid) const
-    {
-        const auto &three_offsets_ =
-            three_offsets_indexer_(static_cast<py::ssize_t>(wid.get(0)));
-
-        const auto &inp1_offset = three_offsets_.get_first_offset();
-        const auto &inp2_offset = three_offsets_.get_second_offset();
-        const auto &out_offset = three_offsets_.get_third_offset();
-
-        out[out_offset] = in1[inp1_offset] + in2[inp2_offset];
-    }
-};
-
 template <typename T1, typename T2, typename resT, typename IndexerT>
 class add_strided_strided_kernel;
 
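The strided variant goes the same way: the removed AddStridedFunctor is now obtained from the shared elementwise_common::BinaryStridedFunctor through the alias declared earlier. A minimal sketch of what that shared functor presumably does, assuming the same three-offsets indexer interface used by the removed code (SketchedBinaryStridedFunctor is a hypothetical name, and std::ptrdiff_t stands in for the py::ssize_t used in this file):

    #include <cstddef>
    #include <sycl/sycl.hpp>

    template <typename argT1,
              typename argT2,
              typename resT,
              typename ThreeOffsets_IndexerT,
              typename BinaryOperatorT>
    struct SketchedBinaryStridedFunctor
    {
        const argT1 *in1 = nullptr;
        const argT2 *in2 = nullptr;
        resT *out = nullptr;
        ThreeOffsets_IndexerT indexer_;

        SketchedBinaryStridedFunctor(const argT1 *inp1_tp,
                                     const argT2 *inp2_tp,
                                     resT *res_tp,
                                     ThreeOffsets_IndexerT inps_res_indexer)
            : in1(inp1_tp), in2(inp2_tp), out(res_tp), indexer_(inps_res_indexer)
        {
        }

        void operator()(sycl::id<1> wid) const
        {
            // The indexer maps the flat work-item id to one offset per array,
            // accounting for arbitrary strides; only the combining step is
            // operation-specific.
            const auto &offsets =
                indexer_(static_cast<std::ptrdiff_t>(wid.get(0)));

            out[offsets.get_third_offset()] =
                BinaryOperatorT{}(in1[offsets.get_first_offset()],
                                  in2[offsets.get_second_offset()]);
        }
    };

The AddStridedFunctor alias simply pins BinaryOperatorT to AddFunctor, so this header only has to declare the payload.
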
@@ -435,40 +361,41 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
         size_t n_groups = (n_elems + lws - 1) / lws;
         auto gwsRange = sycl::range<1>(n_groups * lws);
 
-        cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
-            sycl::nd_range<1>(gwsRange, lwsRange),
-            [=](sycl::nd_item<1> ndit)
-            {
-                auto sg = ndit.get_sub_group();
-                size_t gid = ndit.get_global_linear_id();
+        cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
+            sycl::nd_range<1>(gwsRange, lwsRange),
+            [=](sycl::nd_item<1> ndit)
+        {
+            auto sg = ndit.get_sub_group();
+            size_t gid = ndit.get_global_linear_id();
 
-                std::uint8_t sgSize = sg.get_local_range()[0];
-                size_t base = gid - sg.get_local_id()[0];
+            std::uint8_t sgSize = sg.get_local_range()[0];
+            size_t base = gid - sg.get_local_id()[0];
 
-                if (base + sgSize < n_elems) {
-                    using in_ptrT1 = sycl::multi_ptr<
-                        const argT1, sycl::access::address_space::global_space>;
-                    using in_ptrT2 = sycl::multi_ptr<
-                        const argT2, sycl::access::address_space::global_space>;
-                    using res_ptrT = sycl::multi_ptr<
-                        resT, sycl::access::address_space::global_space>;
+            if (base + sgSize < n_elems) {
+                using in_ptrT1 =
+                    sycl::multi_ptr<const argT1,
+                                    sycl::access::address_space::global_space>;
+                using in_ptrT2 =
+                    sycl::multi_ptr<const argT2,
+                                    sycl::access::address_space::global_space>;
+                using res_ptrT =
+                    sycl::multi_ptr<resT,
+                                    sycl::access::address_space::global_space>;
 
-                    const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
-                    const argT2 vec_el =
-                        sg.load(in_ptrT2(&padded_vec[base % n1]));
+                const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
+                const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
 
-                    resT res_el = mat_el + vec_el;
+                resT res_el = mat_el + vec_el;
 
-                    sg.store(res_ptrT(&res[base]), res_el);
-                }
-                else {
-                    for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
-                         k += sgSize) {
-                        res[k] = mat[k] + padded_vec[k % n1];
-                    }
-                }
+                sg.store(res_ptrT(&res[base]), res_el);
+            }
+            else {
+                for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
+                     k += sgSize) {
+                    res[k] = mat[k] + padded_vec[k % n1];
                 }
-        );
+            }
+        });
     });
 
     sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
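The matrix/row-vector hunk above only re-wraps and re-indents the kernel lambda; its logic is unchanged. For reference, the indexing it relies on is the usual C-contiguous row broadcast, sketched on the host below (a free-standing illustration with hypothetical names, not code from this file):

    #include <cstddef>
    #include <vector>

    // Mirrors res[k] = mat[k] + padded_vec[k % n1]: for a C-contiguous (n0, n1)
    // matrix, the flat index is k = i * n1 + j, so k % n1 recovers the column j
    // and selects the matching row-vector element for every row i.
    std::vector<float> add_row_broadcast(const std::vector<float> &mat,
                                         const std::vector<float> &vec,
                                         std::size_t n0,
                                         std::size_t n1)
    {
        std::vector<float> res(n0 * n1);
        for (std::size_t k = 0; k < n0 * n1; ++k) {
            res[k] = mat[k] + vec[k % n1];
        }
        return res;
    }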