Skip to content

Commit 34b93bf

Browse files
authored
[SYCL] Fix support for classes implicitly converted from items in parallel_for (#5118)
Section 4.9.4.2.2 of the SYCL 2020 specification says that the function object representing the SYCL kernel function may take a parameter of any type that is implicitly convertible from a SYCL item, which represents the currently executing work-item within the range specified by the range parameter. This PR adds that support. E2E test: intel/llvm-test-suite#607
1 parent 2bc3c92 commit 34b93bf

File tree

2 files changed

+75
-6
lines changed

2 files changed

+75
-6
lines changed

sycl/include/CL/sycl/handler.hpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,14 @@ class __SYCL_EXPORT handler {
921921
AccessMode == access::mode::discard_read_write;
922922
}
923923

924+
template <int Dims, typename LambdaArgType> struct TransformUserItemType {
925+
using type = typename std::conditional<
926+
std::is_convertible<nd_item<Dims>, LambdaArgType>::value, nd_item<Dims>,
927+
typename std::conditional<
928+
std::is_convertible<item<Dims>, LambdaArgType>::value, item<Dims>,
929+
LambdaArgType>::type>::type;
930+
};
931+
924932
/// Defines and invokes a SYCL kernel function for the specified range.
925933
///
926934
/// The SYCL kernel function is defined as a lambda function or a named
@@ -939,10 +947,12 @@ class __SYCL_EXPORT handler {
939947
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
940948

941949
// If 1D kernel argument is an integral type, convert it to sycl::item<1>
942-
using TransformedArgType =
943-
typename std::conditional<std::is_integral<LambdaArgType>::value &&
944-
Dims == 1,
945-
item<Dims>, LambdaArgType>::type;
950+
// If user type is convertible from sycl::item/sycl::nd_item, use
951+
// sycl::item/sycl::nd_item to transport item information
952+
using TransformedArgType = typename std::conditional<
953+
std::is_integral<LambdaArgType>::value && Dims == 1, item<Dims>,
954+
typename TransformUserItemType<Dims, LambdaArgType>::type>::type;
955+
946956
using NameT =
947957
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
948958

@@ -1564,12 +1574,17 @@ class __SYCL_EXPORT handler {
15641574
verifyUsedKernelBundle(detail::KernelInfo<NameT>::getName());
15651575
using LambdaArgType =
15661576
sycl::detail::lambda_arg_type<KernelType, nd_item<Dims>>;
1577+
// If user type is convertible from sycl::item/sycl::nd_item, use
1578+
// sycl::item/sycl::nd_item to transport item information
1579+
using TransformedArgType =
1580+
typename TransformUserItemType<Dims, LambdaArgType>::type;
15671581
(void)ExecutionRange;
1568-
kernel_parallel_for_wrapper<NameT, LambdaArgType>(KernelFunc);
1582+
kernel_parallel_for_wrapper<NameT, TransformedArgType>(KernelFunc);
15691583
#ifndef __SYCL_DEVICE_ONLY__
15701584
detail::checkValueRange<Dims>(ExecutionRange);
15711585
MNDRDesc.set(std::move(ExecutionRange));
1572-
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
1586+
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
1587+
std::move(KernelFunc));
15731588
setType(detail::CG::Kernel);
15741589
#endif
15751590
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %RUN_ON_HOST %t.out
3+
4+
// This test performs a basic check of support for user-defined classes that are
5+
// implicitly convertible from sycl::item/sycl::nd_item in parallel_for.
6+
7+
#include <CL/sycl.hpp>
8+
#include <iostream>
9+
10+
template <int Dimensions> class item_wrapper {
11+
public:
12+
item_wrapper(sycl::item<Dimensions> it) : m_item(it) {}
13+
14+
private:
15+
sycl::item<Dimensions> m_item;
16+
};
17+
18+
template <int Dimensions> class nd_item_wrapper {
19+
public:
20+
nd_item_wrapper(sycl::nd_item<Dimensions> it) : m_item(it) {}
21+
22+
private:
23+
sycl::nd_item<Dimensions> m_item;
24+
};
25+
26+
template <int Dimensions, typename T> class item_wrapper2 {
27+
public:
28+
item_wrapper2(sycl::item<Dimensions> it) : m_item(it), m_value(T()) {}
29+
30+
private:
31+
sycl::item<Dimensions> m_item;
32+
T m_value;
33+
};
34+
35+
template <int Dimensions, typename T> class nd_item_wrapper2 {
36+
public:
37+
nd_item_wrapper2(sycl::nd_item<Dimensions> it) : m_item(it), m_value(T()) {}
38+
39+
private:
40+
sycl::nd_item<Dimensions> m_item;
41+
T m_value;
42+
};
43+
44+
int main() {
45+
sycl::queue q;
46+
47+
q.parallel_for(sycl::range<1>{1}, [=](item_wrapper<1> item) {});
48+
q.parallel_for(sycl::nd_range<1>{1, 1}, [=](nd_item_wrapper<1> item) {});
49+
q.parallel_for(sycl::range<1>{1}, [=](item_wrapper2<1, int> item) {});
50+
q.parallel_for(sycl::nd_range<1>{1, 1},
51+
[=](nd_item_wrapper2<1, int> item) {});
52+
53+
return 0;
54+
}

0 commit comments

Comments
 (0)