Skip to content

Commit 0508a19

Browse files
[SYCL][E2E] Fix use-after-free in reduction_range_N_vars.cpp (#11499)
Same as we fixed in #9112.
1 parent 8d4d91c commit 0508a19

File tree

1 file changed

+16
-24
lines changed

1 file changed

+16
-24
lines changed

sycl/test-e2e/Reduction/reduction_range_N_vars.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -128,30 +128,22 @@ int test(queue &Q, size_t NWorkItems, RedTys... Reds) {
128128
Q.submit([&](handler &CGH) {
129129
auto InAcc = sycl::detail::make_tuple(
130130
accessor(Reds.InBuf, CGH, sycl::read_only)...);
131-
auto SyclReds = std::forward_as_tuple(Reds.createRed(CGH)...);
132-
std::apply(
133-
[&](auto... SyclReds) {
134-
CGH.parallel_for<Name>(
135-
Range, SyclReds..., [=](id<1> It, auto &...Reducers) {
136-
static_assert(sizeof...(Reducers) == 4 ||
137-
sizeof...(Reducers) == 2);
138-
// No C++20, so don't have explicit template param lists in
139-
// lambda and can't unfold std::integer_sequence to write
140-
// generic code here.
141-
auto ReducersTuple = std::forward_as_tuple(Reducers...);
142-
size_t I = It.get(0);
143-
144-
std::get<0>(ReducersTuple).combine(std::get<0>(InAcc)[I]);
145-
std::get<1>(ReducersTuple).combine(std::get<1>(InAcc)[I]);
146-
if constexpr (sizeof...(Reds) == 4) {
147-
std::get<2>(ReducersTuple).combine(std::get<2>(InAcc)[I]);
148-
std::get<3>(ReducersTuple).combine(std::get<3>(InAcc)[I]);
149-
}
150-
151-
return;
152-
});
153-
},
154-
SyclReds);
131+
CGH.parallel_for<Name>(
132+
Range, Reds.createRed(CGH)..., [=](id<1> It, auto &...Reducers) {
133+
static_assert(sizeof...(Reducers) == 4 || sizeof...(Reducers) == 2);
134+
// No C++20, so don't have explicit template param lists in
135+
// lambda and can't unfold std::integer_sequence to write
136+
// generic code here.
137+
auto ReducersTuple = std::forward_as_tuple(Reducers...);
138+
size_t I = It.get(0);
139+
140+
std::get<0>(ReducersTuple).combine(std::get<0>(InAcc)[I]);
141+
std::get<1>(ReducersTuple).combine(std::get<1>(InAcc)[I]);
142+
if constexpr (sizeof...(Reds) == 4) {
143+
std::get<2>(ReducersTuple).combine(std::get<2>(InAcc)[I]);
144+
std::get<3>(ReducersTuple).combine(std::get<3>(InAcc)[I]);
145+
}
146+
});
155147
}).wait();
156148

157149
int NumErrors = (0 + ... + Reds.checkResult(Range));

0 commit comments

Comments
 (0)