Skip to content

Commit d59bfdd

Browse files
committed
[SYCL] Implement braced-init-list or a number as range for queue::parallel_for
Modification: Make three different overloads for queue::parallel for to support range implicit conversion from number or braced-init-list Add tests for queue::parallel_for calls with generic lambda Signed-off-by: Ruslan Arutyunyan <ruslan.arutyunyan@intel.com>
1 parent 616a396 commit d59bfdd

File tree

4 files changed

+215
-9
lines changed

4 files changed

+215
-9
lines changed

sycl/include/CL/sycl/queue.hpp

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,10 +432,9 @@ class __SYCL_EXPORT queue {
432432
/// \param NumWorkItems is a range that specifies the work space of the kernel
433433
/// \param KernelFunc is the Kernel functor or lambda
434434
/// \param CodeLoc contains the code location of user code
435-
template <typename KernelName = detail::auto_name, typename KernelType,
436-
int Dims>
435+
template <typename KernelName = detail::auto_name, typename KernelType>
437436
event parallel_for(
438-
range<Dims> NumWorkItems, KernelType KernelFunc
437+
range<1> NumWorkItems, KernelType KernelFunc
439438
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
440439
,
441440
const detail::code_location &CodeLoc = detail::code_location::current()
@@ -444,12 +443,47 @@ class __SYCL_EXPORT queue {
444443
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
445444
const detail::code_location &CodeLoc = {};
446445
#endif
447-
return submit(
448-
[&](handler &CGH) {
449-
CGH.template parallel_for<KernelName, KernelType>(NumWorkItems,
450-
KernelFunc);
451-
},
452-
CodeLoc);
446+
return parallel_for_impl(NumWorkItems, KernelFunc, CodeLoc);
447+
}
448+
449+
/// parallel_for version with a kernel represented as a lambda + range that
450+
/// specifies global size only.
451+
///
452+
/// \param NumWorkItems is a range that specifies the work space of the kernel
453+
/// \param KernelFunc is the Kernel functor or lambda
454+
/// \param CodeLoc contains the code location of user code
455+
template <typename KernelName = detail::auto_name, typename KernelType>
456+
event parallel_for(
457+
range<2> NumWorkItems, KernelType KernelFunc
458+
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
459+
,
460+
const detail::code_location &CodeLoc = detail::code_location::current()
461+
#endif
462+
) {
463+
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
464+
const detail::code_location &CodeLoc = {};
465+
#endif
466+
return parallel_for_impl(NumWorkItems, KernelFunc, CodeLoc);
467+
}
468+
469+
/// parallel_for version with a kernel represented as a lambda + range that
470+
/// specifies global size only.
471+
///
472+
/// \param NumWorkItems is a range that specifies the work space of the kernel
473+
/// \param KernelFunc is the Kernel functor or lambda
474+
/// \param CodeLoc contains the code location of user code
475+
template <typename KernelName = detail::auto_name, typename KernelType>
476+
event parallel_for(
477+
range<3> NumWorkItems, KernelType KernelFunc
478+
#ifndef DISABLE_SYCL_INSTRUMENTATION_METADATA
479+
,
480+
const detail::code_location &CodeLoc = detail::code_location::current()
481+
#endif
482+
) {
483+
#ifdef DISABLE_SYCL_INSTRUMENTATION_METADATA
484+
const detail::code_location &CodeLoc = {};
485+
#endif
486+
return parallel_for_impl(NumWorkItems, KernelFunc, CodeLoc);
453487
}
454488

455489
/// parallel_for version with a kernel represented as a lambda + range that
@@ -716,6 +750,25 @@ class __SYCL_EXPORT queue {
716750
/// A template-free version of submit.
717751
event submit_impl(function_class<void(handler &)> CGH, queue secondQueue,
718752
const detail::code_location &CodeLoc);
753+
754+
/// parallel_for_impl with a kernel represented as a lambda + range that
755+
/// specifies global size only.
756+
///
757+
/// \param NumWorkItems is a range that specifies the work space of the kernel
758+
/// \param KernelFunc is the Kernel functor or lambda
759+
/// \param CodeLoc contains the code location of user code
760+
template <typename KernelName = detail::auto_name, typename KernelType,
761+
int Dims>
762+
event parallel_for_impl(
763+
range<Dims> NumWorkItems, KernelType KernelFunc,
764+
const detail::code_location &CodeLoc = detail::code_location::current()) {
765+
return submit(
766+
[&](handler &CGH) {
767+
CGH.template parallel_for<KernelName, KernelType>(NumWorkItems,
768+
KernelFunc);
769+
},
770+
CodeLoc);
771+
}
719772
};
720773

721774
} // namespace sycl
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// UNSUPPORTED: cuda
2+
// CUDA does not support unnamed lambdas.
3+
//
4+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -fsycl-unnamed-lambda %s -o %t.out
5+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
6+
// RUN: %ACC_RUN_PLACEHOLDER %t.out
7+
// RUN: %CPU_RUN_PLACEHOLDER %t.out
8+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
9+
10+
//==- queue_parallel_for_generic.cpp - SYCL queue parallel_for generic lambda -=//
11+
//
12+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
13+
// See https://llvm.org/LICENSE.txt for license information.
14+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
15+
//
16+
//===------------------------------------------------------------------------===//
17+
18+
#include <CL/sycl.hpp>
19+
#include <iostream>
20+
#include <type_traits>
21+
22+
int main() {
23+
sycl::queue q{};
24+
auto dev = q.get_device();
25+
auto ctx = q.get_context();
26+
constexpr int N = 8;
27+
28+
if (dev.get_info<sycl::info::device::usm_shared_allocations>()) {
29+
auto A = static_cast<int *>(sycl::malloc_shared(N * sizeof(int), dev, ctx));
30+
31+
for (int i = 0; i < N; i++) {
32+
A[i] = 1;
33+
}
34+
35+
q.parallel_for(N, [=](auto i) {
36+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
37+
"lambda arg type is unexpected");
38+
A[i]++;
39+
});
40+
41+
q.parallel_for<class Foo>({N}, [=](auto i) {
42+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
43+
"lambda arg type is unexpected");
44+
A[i]++;
45+
});
46+
47+
sycl::id<1> offset(0);
48+
q.parallel_for<class Baz>(sycl::range<1>{N}, offset, [=](auto i) {
49+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
50+
"lambda arg type is unexpected");
51+
A[i]++;
52+
});
53+
54+
sycl::nd_range<1> NDR(sycl::range<1>{N}, sycl::range<1>{2});
55+
q.parallel_for<class NDFoo>(NDR, [=](auto nd_i) {
56+
static_assert(std::is_same<decltype(nd_i), sycl::nd_item<1>>::value,
57+
"lambda arg type is unexpected");
58+
auto i = nd_i.get_global_id(0);
59+
A[i]++;
60+
});
61+
62+
q.wait();
63+
64+
for (int i = 0; i < N; i++) {
65+
if (A[i] != 5)
66+
return 1;
67+
}
68+
sycl::free(A, ctx);
69+
}
70+
71+
return 0;
72+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// UNSUPPORTED: cuda
2+
// CUDA does not support unnamed lambdas.
3+
//
4+
// RUN: %clangxx -fsycl -fsyntax-only -fsycl-unnamed-lambda %s -o %t.out
5+
6+
//==- queue_parallel_for_generic.cpp - SYCL queue parallel_for interface test -=//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===------------------------------------------------------------------------===//
13+
14+
#include <CL/sycl.hpp>
15+
#include <iostream>
16+
#include <type_traits>
17+
18+
template <std::size_t... Is>
19+
void test_range_impl(sycl::queue q, std::index_sequence<Is...>,
20+
sycl::range<sizeof...(Is)> *) {
21+
constexpr auto dims = sizeof...(Is);
22+
23+
q.parallel_for(sycl::range<dims>{Is...}, [=](auto i) {
24+
static_assert(std::is_same<decltype(i), sycl::item<dims>>::value,
25+
"lambda arg type is unexpected");
26+
});
27+
}
28+
29+
template <std::size_t... Is>
30+
void test_range_impl(sycl::queue q, std::index_sequence<Is...>,
31+
sycl::nd_range<sizeof...(Is)> *) {
32+
constexpr auto dims = sizeof...(Is);
33+
34+
sycl::nd_range<dims> ndr{sycl::range<dims>{Is...}, sycl::range<dims>{Is...}};
35+
q.parallel_for(ndr, [=](auto i) {
36+
static_assert(std::is_same<decltype(i), sycl::nd_item<dims>>::value,
37+
"lambda arg type is unexpected");
38+
});
39+
}
40+
41+
template <template <int> class Range, std::size_t Dims>
42+
void test_range(sycl::queue q) {
43+
test_range_impl(q, std::make_index_sequence<Dims>{},
44+
static_cast<Range<Dims> *>(nullptr));
45+
}
46+
47+
void test_number_braced_init_list(sycl::queue q) {
48+
constexpr auto n = 1;
49+
q.parallel_for(n, [=](auto i) {
50+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
51+
"lambda arg type is unexpected");
52+
});
53+
54+
q.parallel_for({n}, [=](auto i) {
55+
static_assert(std::is_same<decltype(i), sycl::item<1>>::value,
56+
"lambda arg type is unexpected");
57+
});
58+
59+
q.parallel_for({n, n}, [=](auto i) {
60+
static_assert(std::is_same<decltype(i), sycl::item<2>>::value,
61+
"lambda arg type is unexpected");
62+
});
63+
64+
q.parallel_for({n, n, n}, [=](auto i) {
65+
static_assert(std::is_same<decltype(i), sycl::item<3>>::value,
66+
"lambda arg type is unexpected");
67+
});
68+
}
69+
70+
int main() {
71+
sycl::queue q{};
72+
73+
test_number_braced_init_list(q);
74+
75+
test_range<sycl::range, 1>(q);
76+
test_range<sycl::range, 2>(q);
77+
test_range<sycl::range, 3>(q);
78+
test_range<sycl::nd_range, 1>(q);
79+
test_range<sycl::nd_range, 2>(q);
80+
test_range<sycl::nd_range, 3>(q);
81+
}

0 commit comments

Comments
 (0)