Skip to content

Commit f3535f1

Browse files
Created templates for binary functions too, applied for addition
1 parent: 58ff5be · commit: f3535f1

File tree

2 files changed

+242
-145
lines changed

2 files changed

+242
-145
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/add.hpp

Lines changed: 72 additions & 145 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
#include "utils/offset_utils.hpp"
88
#include "utils/type_dispatch.hpp"
99
#include "utils/type_utils.hpp"
10+
11+
#include "kernels/elementwise_functions/common.hpp"
1012
#include <pybind11/pybind11.h>
1113

1214
namespace dpctl
@@ -20,101 +22,60 @@ namespace add
2022

2123
namespace py = pybind11;
2224
namespace td_ns = dpctl::tensor::type_dispatch;
25+
namespace tu_ns = dpctl::tensor::type_utils;
2326

24-
template <typename argT1,
25-
typename argT2,
26-
typename resT,
27-
unsigned int vec_sz = 4,
28-
unsigned int n_vecs = 2>
29-
struct AddContigFunctor
27+
template <typename argT1, typename argT2, typename resT> struct AddFunctor
3028
{
31-
private:
32-
const argT1 *in1 = nullptr;
33-
const argT2 *in2 = nullptr;
34-
resT *out = nullptr;
35-
const size_t nelems_;
36-
37-
public:
38-
AddContigFunctor(const argT1 *inp1,
39-
const argT2 *inp2,
40-
resT *res,
41-
const size_t n_elems)
42-
: in1(inp1), in2(inp2), out(res), nelems_(n_elems)
29+
30+
using supports_sg_loadstore = std::negation<
31+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
32+
using supports_vec = std::negation<
33+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
34+
35+
resT operator()(const argT1 &in1, const argT2 &in2)
4336
{
37+
return in1 + in2;
4438
}
4539

46-
void operator()(sycl::nd_item<1> ndit) const
40+
template <int vec_sz>
41+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
42+
const sycl::vec<argT2, vec_sz> &in2)
4743
{
48-
/* Each work-item processes vec_sz elements, contiguous in memory */
49-
50-
using dpctl::tensor::type_utils::is_complex;
51-
if constexpr (is_complex<argT1>::value || is_complex<argT2>::value) {
52-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
53-
size_t base = ndit.get_global_linear_id();
54-
55-
base = (base / sgSize) * sgSize * n_vecs * vec_sz + (base % sgSize);
56-
for (size_t offset = base;
57-
offset < std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
58-
offset += sgSize)
59-
{
60-
out[offset] = in1[offset] + in2[offset];
61-
}
44+
auto tmp = in1 + in2;
45+
if constexpr (std::is_same_v<resT,
46+
typename decltype(tmp)::element_type>) {
47+
return tmp;
6248
}
6349
else {
64-
auto sg = ndit.get_sub_group();
65-
std::uint8_t sgSize = sg.get_local_range()[0];
66-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
67-
size_t base = n_vecs * vec_sz *
68-
(ndit.get_group(0) * ndit.get_local_range(0) +
69-
sg.get_group_id()[0] * maxsgSize);
50+
using dpctl::tensor::type_utils::vec_cast;
7051

71-
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
72-
(sgSize == maxsgSize)) {
73-
using in_ptrT1 =
74-
sycl::multi_ptr<const argT1,
75-
sycl::access::address_space::global_space>;
76-
using in_ptrT2 =
77-
sycl::multi_ptr<const argT2,
78-
sycl::access::address_space::global_space>;
79-
using out_ptrT =
80-
sycl::multi_ptr<resT,
81-
sycl::access::address_space::global_space>;
82-
sycl::vec<argT1, vec_sz> arg1_vec;
83-
sycl::vec<argT2, vec_sz> arg2_vec;
84-
sycl::vec<resT, vec_sz> res_vec;
85-
86-
#pragma unroll
87-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
88-
arg1_vec =
89-
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
90-
arg2_vec =
91-
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
92-
if constexpr (std::is_same_v<argT1, resT> &&
93-
std::is_same_v<argT2, resT>) {
94-
res_vec = arg1_vec + arg2_vec;
95-
}
96-
else {
97-
using dpctl::tensor::type_utils::vec_cast;
98-
99-
auto tmp = arg1_vec + arg2_vec;
100-
res_vec = std::move(
101-
vec_cast<resT, typename decltype(tmp)::element_type,
102-
vec_sz>(tmp));
103-
}
104-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
105-
res_vec);
106-
}
107-
}
108-
else {
109-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
110-
k += sgSize) {
111-
out[k] = in1[k] + in2[k];
112-
}
113-
}
52+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
53+
tmp);
11454
}
11555
}
11656
};
11757

58+
template <typename argT1,
59+
typename argT2,
60+
typename resT,
61+
unsigned int vec_sz = 4,
62+
unsigned int n_vecs = 2>
63+
using AddContigFunctor =
64+
elementwise_common::BinaryContigFunctor<argT1,
65+
argT2,
66+
resT,
67+
AddFunctor<argT1, argT2, resT>,
68+
vec_sz,
69+
n_vecs>;
70+
71+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
72+
using AddStridedFunctor =
73+
elementwise_common::BinaryStridedFunctor<argT1,
74+
argT2,
75+
resT,
76+
IndexerT,
77+
AddFunctor<argT1, argT2, resT>>;
78+
11879
template <typename T1, typename T2> struct AddOutputType
11980
{
12081
using value_type = typename std::disjunction< // disjunction is C++17
@@ -257,41 +218,6 @@ template <typename fnT, typename T1, typename T2> struct AddTypeMapFactory
257218
}
258219
};
259220

260-
template <typename argT1,
261-
typename argT2,
262-
typename resT,
263-
typename ThreeOffsets_IndexerT>
264-
struct AddStridedFunctor
265-
{
266-
private:
267-
const argT1 *in1 = nullptr;
268-
const argT2 *in2 = nullptr;
269-
resT *out = nullptr;
270-
ThreeOffsets_IndexerT three_offsets_indexer_;
271-
272-
public:
273-
AddStridedFunctor(const argT1 *inp1_tp,
274-
const argT2 *inp2_tp,
275-
resT *res_tp,
276-
ThreeOffsets_IndexerT inps_res_indexer)
277-
: in1(inp1_tp), in2(inp2_tp), out(res_tp),
278-
three_offsets_indexer_(inps_res_indexer)
279-
{
280-
}
281-
282-
void operator()(sycl::id<1> wid) const
283-
{
284-
const auto &three_offsets_ =
285-
three_offsets_indexer_(static_cast<py::ssize_t>(wid.get(0)));
286-
287-
const auto &inp1_offset = three_offsets_.get_first_offset();
288-
const auto &inp2_offset = three_offsets_.get_second_offset();
289-
const auto &out_offset = three_offsets_.get_third_offset();
290-
291-
out[out_offset] = in1[inp1_offset] + in2[inp2_offset];
292-
}
293-
};
294-
295221
template <typename T1, typename T2, typename resT, typename IndexerT>
296222
class add_strided_strided_kernel;
297223

@@ -435,40 +361,41 @@ sycl::event add_contig_matrix_contig_row_broadcast_impl(
435361
size_t n_groups = (n_elems + lws - 1) / lws;
436362
auto gwsRange = sycl::range<1>(n_groups * lws);
437363

438-
cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
439-
sycl::nd_range<1>(gwsRange, lwsRange),
440-
[=](sycl::nd_item<1> ndit)
441-
{
442-
auto sg = ndit.get_sub_group();
443-
size_t gid = ndit.get_global_linear_id();
364+
cgh.parallel_for<class add_matrix_vector_broadcast_sg_krn<argT1, argT2, resT>>(
365+
sycl::nd_range<1>(gwsRange, lwsRange),
366+
[=](sycl::nd_item<1> ndit)
367+
{
368+
auto sg = ndit.get_sub_group();
369+
size_t gid = ndit.get_global_linear_id();
444370

445-
std::uint8_t sgSize = sg.get_local_range()[0];
446-
size_t base = gid - sg.get_local_id()[0];
371+
std::uint8_t sgSize = sg.get_local_range()[0];
372+
size_t base = gid - sg.get_local_id()[0];
447373

448-
if (base + sgSize < n_elems) {
449-
using in_ptrT1 = sycl::multi_ptr<
450-
const argT1, sycl::access::address_space::global_space>;
451-
using in_ptrT2 = sycl::multi_ptr<
452-
const argT2, sycl::access::address_space::global_space>;
453-
using res_ptrT = sycl::multi_ptr<
454-
resT, sycl::access::address_space::global_space>;
374+
if (base + sgSize < n_elems) {
375+
using in_ptrT1 =
376+
sycl::multi_ptr<const argT1,
377+
sycl::access::address_space::global_space>;
378+
using in_ptrT2 =
379+
sycl::multi_ptr<const argT2,
380+
sycl::access::address_space::global_space>;
381+
using res_ptrT =
382+
sycl::multi_ptr<resT,
383+
sycl::access::address_space::global_space>;
455384

456-
const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
457-
const argT2 vec_el =
458-
sg.load(in_ptrT2(&padded_vec[base % n1]));
385+
const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
386+
const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
459387

460-
resT res_el = mat_el + vec_el;
388+
resT res_el = mat_el + vec_el;
461389

462-
sg.store(res_ptrT(&res[base]), res_el);
463-
}
464-
else {
465-
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
466-
k += sgSize) {
467-
res[k] = mat[k] + padded_vec[k % n1];
468-
}
469-
}
390+
sg.store(res_ptrT(&res[base]), res_el);
391+
}
392+
else {
393+
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
394+
k += sgSize) {
395+
res[k] = mat[k] + padded_vec[k % n1];
470396
}
471-
);
397+
}
398+
});
472399
});
473400

474401
sycl::event tmp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {

0 commit comments

Comments
 (0)