@@ -27,16 +27,28 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a, task_t<ExecutionPolicyB, KernelB>
27
27
using new_execution_policy = named_distributed_execution_policy<
28
28
indexed_kernel_name_t <fused<ExecutionPolicyA, ExecutionPolicyB>>>;
29
29
30
- using kernel_type = std::invoke_result_t <decltype (a.get_sequence ()), handler &>;
31
- using item_type = traits::arg_type_t <kernel_type , 0 >;
30
+ using kernel_a_type = std::invoke_result_t <decltype (a.get_sequence ()), handler &>;
31
+ using context_a_type = std:: decay_t < traits::arg_type_t <kernel_a_type , 0 > >;
32
32
33
- auto seq = a.get_sequence () | b.get_sequence ();
33
+ using kernel_b_type = std::invoke_result_t <decltype (b.get_sequence ()), handler &>;
34
+ using context_b_type = std::decay_t <traits::arg_type_t <kernel_b_type, 0 >>;
35
+
36
+ using combined_context_type = combined_context_t <context_a_type, context_b_type>;
34
37
35
38
auto f = [=](handler &cgh) {
36
- const auto kernels = sequence (std::invoke (seq, cgh));
39
+ const auto kernels_a = sequence (std::invoke (a.get_sequence (), cgh));
40
+ const auto kernels_b = sequence (std::invoke (b.get_sequence (), cgh));
41
+
42
+ return [=](combined_context_type &ctx) {
43
+ context_a_type ctx_a (ctx.get_item ());
44
+ ctx_a.copy_in (ctx);
45
+
46
+ kernels_a (ctx_a);
37
47
38
- return [=](item_type item) {
39
- kernels (item);
48
+ context_b_type ctx_b{ctx_a, ctx};
49
+ kernels_b (ctx_b);
50
+
51
+ ctx.copy_out (ctx_b);
40
52
};
41
53
};
42
54
@@ -60,8 +72,16 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a,
60
72
using new_execution_policy = named_distributed_execution_policy<
61
73
indexed_kernel_name_t <fused<fused<ExecutionPolicyA, ExecutionPolicyB>, ExecutionPolicyC>>>;
62
74
63
- using kernel_type = std::invoke_result_t <decltype (a.get_sequence ()), handler &>;
64
- using item_type = traits::arg_type_t <kernel_type, 0 >;
75
+ using kernel_a_type = std::invoke_result_t <decltype (a.get_sequence ()), handler &>;
76
+ using context_a_type = std::decay_t <traits::arg_type_t <kernel_a_type, 0 >>;
77
+
78
+ using kernel_b_type = std::invoke_result_t <decltype (b.get_sequence ()), handler &>;
79
+ using context_b_type = std::decay_t <traits::arg_type_t <kernel_b_type, 0 >>;
80
+
81
+ using kernel_c_type = std::invoke_result_t <decltype (c.get_sequence ()), handler &>;
82
+ using context_c_type = std::decay_t <traits::arg_type_t <kernel_c_type, 0 >>;
83
+
84
+ using combined_context_type = combined_context_t <context_a_type, context_c_type>;
65
85
66
86
auto seq_a = a.get_sequence ();
67
87
auto seq_b = b.get_sequence ();
@@ -72,29 +92,18 @@ auto fuse(task_t<ExecutionPolicyA, KernelA> a,
72
92
const auto kernels_b = sequence (std::invoke (seq_b, cgh));
73
93
const auto kernels_c = sequence (std::invoke (seq_c, cgh));
74
94
75
- return [=](item_type item) {
76
- kernels_a (item);
77
- // data[0] = result of a
78
- // data[1] = empty
79
-
80
- // switch item context so that
81
- // the b-kernels write to the
82
- // second data store
83
- item.switch_data ();
84
- // data[0] = empty
85
- // data[1] = result of a
86
-
87
- kernels_b (item);
88
- // data[0] = result of b
89
- // data[1] = result of a
90
-
91
- // switch back to normal
92
- // data[0] = result of a
93
- // data[1] = result of b
94
- item.switch_data ();
95
-
96
- kernels_c (item);
97
- // result of c written to buffer
95
+ return [=](combined_context_type &ctx) {
96
+ context_a_type ctx_a{ctx.get_item ()};
97
+ ctx_a.copy_in (ctx);
98
+ kernels_a (ctx_a);
99
+
100
+ context_b_type ctx_b{ctx.get_item ()};
101
+ kernels_b (ctx_b);
102
+
103
+ context_c_type ctx_c{ctx_a, ctx_b};
104
+ kernels_c (ctx_c);
105
+
106
+ ctx.copy_out (ctx_c);
98
107
};
99
108
};
100
109
@@ -116,8 +125,11 @@ auto fuse_right(task_t<ExecutionPolicyB, KernelB> b,
116
125
using new_execution_policy = named_distributed_execution_policy<
117
126
indexed_kernel_name_t <fused<ExecutionPolicyB, ExecutionPolicyC>>>;
118
127
119
- using kernel_type = std::invoke_result_t <decltype (b.get_sequence ()), handler &>;
120
- using item_type = traits::arg_type_t <kernel_type, 0 >;
128
+ using kernel_b_type = std::invoke_result_t <decltype (b.get_sequence ()), handler &>;
129
+ using context_b_type = std::decay_t <traits::arg_type_t <kernel_b_type, 0 >>;
130
+
131
+ using kernel_c_type = std::invoke_result_t <decltype (c.get_sequence ()), handler &>;
132
+ using context_c_type = std::decay_t <traits::arg_type_t <kernel_c_type, 0 >>;
121
133
122
134
auto seq_b = b.get_sequence ();
123
135
auto seq_c = c.get_sequence ();
@@ -126,26 +138,12 @@ auto fuse_right(task_t<ExecutionPolicyB, KernelB> b,
126
138
const auto kernels_b = sequence (std::invoke (seq_b, cgh));
127
139
const auto kernels_c = sequence (std::invoke (seq_c, cgh));
128
140
129
- return [=](item_type item) {
130
- // switch item context so that
131
- // the b-kernels write to the
132
- // second data store
133
- item.switch_data ();
134
- // data[0] = empty
135
- // data[1] = empty
136
-
137
- kernels_b (item);
138
- // data[0] = result of b
139
- // data[1] = empty
140
-
141
- // switch back to normal
142
- // data[0] = empty
143
- // data[1] = result of b
144
- item.switch_data ();
145
-
146
- kernels_c (item);
147
- // data[0] = result of c
148
- // data[1] = empty
141
+ return [=](context_c_type &ctx_c) {
142
+ context_b_type ctx_b{ctx_c.get_item ()};
143
+ kernels_b (ctx_b);
144
+
145
+ ctx_c.template get_in <1 >() = ctx_b.get_out ();
146
+ kernels_c (ctx_c);
149
147
};
150
148
};
151
149
0 commit comments