Skip to content

Commit 510f329

Browse files
committed
implement heterogeneous item_context
adapt three-way (ha!) and right-hand kernel fusion, update algorithms
1 parent 2d66530 commit 510f329

File tree

8 files changed

+361
-87
lines changed

8 files changed

+361
-87
lines changed

src/algorithms/fill.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ auto fill_impl(IteratorType<T, Rank> beg, IteratorType<T, Rank> end, const T &va
2323
return [=](celerity::handler &cgh) {
2424
auto out_acc = get_access<policy_type, mode::write, one_to_one>(cgh, beg, end);
2525

26-
return [=](item_context<Rank, T> &ctx) {
27-
out_acc[ctx[0]] = value;
26+
return [=](item_context<Rank, T(void)> &ctx) {
27+
out_acc[ctx.get_out()] = value;
2828
};
2929
};
3030
}

src/algorithms/for_each.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ auto for_each_impl(InIterator<T, Rank> beg, InIterator<T, Rank> end, const F &f)
2626
return [=](celerity::handler &cgh) {
2727
auto in_acc = get_access<policy_type, cl::sycl::access::mode::read, accessor_type>(cgh, beg, end);
2828

29-
return [=](item_context<Rank, T> &ctx) {
30-
f(ctx[0], in_acc[ctx[0]]);
29+
return [=](item_context<Rank, void(T)> &ctx) {
30+
f(ctx.get_item(), in_acc[ctx.get_in()]);
3131
};
3232
};
3333
}

src/algorithms/generate.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,16 @@ auto generate_impl(IteratorType<T, Rank> beg, IteratorType<T, Rank> end, const F
2525
return [=](celerity::handler &cgh) {
2626
auto out_acc = get_access<policy_type, mode::write, one_to_one>(cgh, beg, end);
2727

28-
if constexpr (traits::arity_v<F> == 1)
29-
{
30-
return [=](item_context<Rank, T> &ctx) {
31-
out_acc[ctx[0]] = f(ctx[0]);
32-
};
33-
}
34-
else
35-
{
36-
return [=](item_context<Rank, T> &ctx) {
37-
out_acc[ctx[0]] = f();
38-
};
39-
}
28+
return [=](item_context<Rank, T()> &ctx) {
29+
if constexpr (traits::arity_v<F> == 1)
30+
{
31+
out_acc[ctx.get_out()] = f(ctx.get_item());
32+
}
33+
else
34+
{
35+
out_acc[ctx.get_out()] = f();
36+
}
37+
};
4038
};
4139
}
4240

src/algorithms/transform.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ auto transform_impl(InIterator<T, Rank> beg, InIterator<T, Rank> end, OutIterato
2828
auto in_acc = get_access<policy_type, mode::read, accessor_type>(cgh, beg, end);
2929
auto out_acc = get_access<policy_type, mode::discard_write, one_to_one>(cgh, out, out);
3030

31-
return [=](item_context<Rank, T> &ctx) {
32-
out_acc[ctx[0]] = f(in_acc[ctx[0]]);
31+
return [=](item_context<Rank, U(T)> &ctx) {
32+
out_acc[ctx.get_out()] = f(in_acc[ctx.get_in()]);
3333
};
3434
};
3535
}
@@ -48,8 +48,8 @@ auto transform_impl(InIterator<T, Rank> beg, InIterator<T, Rank> end, OutIterato
4848
auto in_acc = get_access<policy_type, mode::read, accessor_type>(cgh, beg, end);
4949
auto out_acc = get_access<policy_type, mode::write, one_to_one>(cgh, out, out);
5050

51-
return [=](item_context<Rank, T> &ctx) {
52-
out_acc[ctx[0]] = f(ctx.get_item(), in_acc[ctx[0]]);
51+
return [=](item_context<Rank, U(T)> &ctx) {
52+
out_acc[ctx.get_out()] = f(ctx.get_item(), in_acc[ctx.get_in()]);
5353
};
5454
};
5555
}
@@ -61,12 +61,13 @@ template <typename ExecutionPolicy,
6161
typename F,
6262
typename T,
6363
typename U,
64+
typename V,
6465
int Rank,
6566
require<traits::function_traits<F>::arity == 2> = yes>
6667
auto transform_impl(FirstInputIteratorType<T, Rank> beg,
6768
FirstInputIteratorType<T, Rank> end,
6869
SecondInputIteratorType<U, Rank> beg2,
69-
OutputIteratorType<T, Rank> out,
70+
OutputIteratorType<V, Rank> out,
7071
const F &f)
7172
{
7273
using namespace traits;
@@ -82,8 +83,8 @@ auto transform_impl(FirstInputIteratorType<T, Rank> beg,
8283
auto out_acc = get_access<policy_type, mode::discard_write, one_to_one>(cgh, out, out);
8384

8485
// item_context is parameterized with the signature V(T, U), so it covers both input types (T, U) and the output type V
85-
return [=](item_context<Rank, T> &ctx) {
86-
out_acc[ctx[0]] = f(first_in_acc[ctx[0]], second_in_acc[ctx[1]]);
86+
return [=](item_context<Rank, V(T, U)> &ctx) {
87+
out_acc[ctx.get_out()] = f(first_in_acc[ctx.template get_in<0>()], second_in_acc[ctx.template get_in<1>()]);
8788
};
8889
};
8990
}
@@ -95,12 +96,13 @@ template <typename ExecutionPolicy,
9596
typename F,
9697
typename T,
9798
typename U,
99+
typename V,
98100
int Rank,
99101
require<traits::function_traits<F>::arity == 3> = yes>
100102
auto transform_impl(FirstInputIteratorType<T, Rank> beg,
101103
FirstInputIteratorType<T, Rank> end,
102104
SecondInputIteratorType<U, Rank> beg2,
103-
OutputIteratorType<T, Rank> out, const F &f)
105+
OutputIteratorType<V, Rank> out, const F &f)
104106
{
105107
using namespace traits;
106108
using namespace cl::sycl::access;
@@ -115,8 +117,8 @@ auto transform_impl(FirstInputIteratorType<T, Rank> beg,
115117
auto out_acc = get_access<policy_type, mode::discard_write, one_to_one>(cgh, out, out);
116118

117119
// item_context is parameterized with the signature V(T, U), so it covers both input types (T, U) and the output type V
118-
return [=](item_context<Rank, T> &ctx) {
119-
out_acc[ctx[0]] = f(ctx.get_item(), first_in_acc[ctx[0]], second_in_acc[ctx[1]]);
120+
return [=](item_context<Rank, V(T, U)> &ctx) {
121+
out_acc[ctx.get_out()] = f(ctx.get_item(), first_in_acc[ctx.template get_in<0>()], second_in_acc[ctx.template get_in<1>()]);
120122
};
121123
};
122124
}

src/fusion.h

Lines changed: 51 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,28 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a, task_t<ExecutionPolicyB, KernelB>
2727
using new_execution_policy = named_distributed_execution_policy<
2828
indexed_kernel_name_t<fused<ExecutionPolicyA, ExecutionPolicyB>>>;
2929

30-
using kernel_type = std::invoke_result_t<decltype(a.get_sequence()), handler &>;
31-
using item_type = traits::arg_type_t<kernel_type, 0>;
30+
using kernel_a_type = std::invoke_result_t<decltype(a.get_sequence()), handler &>;
31+
using context_a_type = std::decay_t<traits::arg_type_t<kernel_a_type, 0>>;
3232

33-
auto seq = a.get_sequence() | b.get_sequence();
33+
using kernel_b_type = std::invoke_result_t<decltype(b.get_sequence()), handler &>;
34+
using context_b_type = std::decay_t<traits::arg_type_t<kernel_b_type, 0>>;
35+
36+
using combined_context_type = combined_context_t<context_a_type, context_b_type>;
3437

3538
auto f = [=](handler &cgh) {
36-
const auto kernels = sequence(std::invoke(seq, cgh));
39+
const auto kernels_a = sequence(std::invoke(a.get_sequence(), cgh));
40+
const auto kernels_b = sequence(std::invoke(b.get_sequence(), cgh));
41+
42+
return [=](combined_context_type &ctx) {
43+
context_a_type ctx_a(ctx.get_item());
44+
ctx_a.copy_in(ctx);
45+
46+
kernels_a(ctx_a);
3747

38-
return [=](item_type item) {
39-
kernels(item);
48+
context_b_type ctx_b{ctx_a, ctx};
49+
kernels_b(ctx_b);
50+
51+
ctx.copy_out(ctx_b);
4052
};
4153
};
4254

@@ -60,8 +72,16 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a,
6072
using new_execution_policy = named_distributed_execution_policy<
6173
indexed_kernel_name_t<fused<fused<ExecutionPolicyA, ExecutionPolicyB>, ExecutionPolicyC>>>;
6274

63-
using kernel_type = std::invoke_result_t<decltype(a.get_sequence()), handler &>;
64-
using item_type = traits::arg_type_t<kernel_type, 0>;
75+
using kernel_a_type = std::invoke_result_t<decltype(a.get_sequence()), handler &>;
76+
using context_a_type = std::decay_t<traits::arg_type_t<kernel_a_type, 0>>;
77+
78+
using kernel_b_type = std::invoke_result_t<decltype(b.get_sequence()), handler &>;
79+
using context_b_type = std::decay_t<traits::arg_type_t<kernel_b_type, 0>>;
80+
81+
using kernel_c_type = std::invoke_result_t<decltype(c.get_sequence()), handler &>;
82+
using context_c_type = std::decay_t<traits::arg_type_t<kernel_c_type, 0>>;
83+
84+
using combined_context_type = combined_context_t<context_a_type, context_c_type>;
6585

6686
auto seq_a = a.get_sequence();
6787
auto seq_b = b.get_sequence();
@@ -72,29 +92,18 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a,
7292
const auto kernels_b = sequence(std::invoke(seq_b, cgh));
7393
const auto kernels_c = sequence(std::invoke(seq_c, cgh));
7494

75-
return [=](item_type item) {
76-
kernels_a(item);
77-
// data[0] = result of a
78-
// data[1] = empty
79-
80-
// switch item context so that
81-
// the b-kernels write to the
82-
// second data store
83-
item.switch_data();
84-
// data[0] = empty
85-
// data[1] = result of a
86-
87-
kernels_b(item);
88-
// data[0] = result of b
89-
// data[1] = result of a
90-
91-
// switch back to normal
92-
// data[0] = result of a
93-
// data[1] = result of b
94-
item.switch_data();
95-
96-
kernels_c(item);
97-
// result of c written to buffer
95+
return [=](combined_context_type &ctx) {
96+
context_a_type ctx_a{ctx.get_item()};
97+
ctx_a.copy_in(ctx);
98+
kernels_a(ctx_a);
99+
100+
context_b_type ctx_b{ctx.get_item()};
101+
kernels_b(ctx_b);
102+
103+
context_c_type ctx_c{ctx_a, ctx_b};
104+
kernels_c(ctx_c);
105+
106+
ctx.copy_out(ctx_c);
98107
};
99108
};
100109

@@ -116,8 +125,11 @@ auto fuse_right(task_t<ExecutionPolicyB, KernelB> b,
116125
using new_execution_policy = named_distributed_execution_policy<
117126
indexed_kernel_name_t<fused<ExecutionPolicyB, ExecutionPolicyC>>>;
118127

119-
using kernel_type = std::invoke_result_t<decltype(b.get_sequence()), handler &>;
120-
using item_type = traits::arg_type_t<kernel_type, 0>;
128+
using kernel_b_type = std::invoke_result_t<decltype(b.get_sequence()), handler &>;
129+
using context_b_type = std::decay_t<traits::arg_type_t<kernel_b_type, 0>>;
130+
131+
using kernel_c_type = std::invoke_result_t<decltype(c.get_sequence()), handler &>;
132+
using context_c_type = std::decay_t<traits::arg_type_t<kernel_c_type, 0>>;
121133

122134
auto seq_b = b.get_sequence();
123135
auto seq_c = c.get_sequence();
@@ -126,26 +138,12 @@ auto fuse_right(task_t<ExecutionPolicyB, KernelB> b,
126138
const auto kernels_b = sequence(std::invoke(seq_b, cgh));
127139
const auto kernels_c = sequence(std::invoke(seq_c, cgh));
128140

129-
return [=](item_type item) {
130-
// switch item context so that
131-
// the b-kernels write to the
132-
// second data store
133-
item.switch_data();
134-
// data[0] = empty
135-
// data[1] = empty
136-
137-
kernels_b(item);
138-
// data[0] = result of b
139-
// data[1] = empty
140-
141-
// switch back to normal
142-
// data[0] = empty
143-
// data[1] = result of b
144-
item.switch_data();
145-
146-
kernels_c(item);
147-
// data[0] = result of c
148-
// data[1] = empty
141+
return [=](context_c_type &ctx_c) {
142+
context_b_type ctx_b{ctx_c.get_item()};
143+
kernels_b(ctx_b);
144+
145+
ctx_c.template get_in<1>() = ctx_b.get_out();
146+
kernels_c(ctx_c);
149147
};
150148
};
151149

0 commit comments

Comments
 (0)