Skip to content

Commit 4524a6b

Browse files
[SYCL] Add DPC++ RT support for SYCL 2020 spec constants (part 1) (#3382)
This patch adds partial implementation of specialization constants in DPC++ RT: 1. Implementation of `specialization_id` class 2. Implementation of `kernel_handler` class 3. Support for user's device lambdas which take `kernel_handler` argument
1 parent 6272da6 commit 4524a6b

File tree

12 files changed

+531
-71
lines changed

12 files changed

+531
-71
lines changed

sycl/include/CL/sycl.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <CL/sycl/item.hpp>
3636
#include <CL/sycl/kernel.hpp>
3737
#include <CL/sycl/kernel_bundle.hpp>
38+
#include <CL/sycl/kernel_handler.hpp>
3839
#include <CL/sycl/marray.hpp>
3940
#include <CL/sycl/multi_ptr.hpp>
4041
#include <CL/sycl/nd_item.hpp>
@@ -48,6 +49,7 @@
4849
#include <CL/sycl/range.hpp>
4950
#include <CL/sycl/reduction.hpp>
5051
#include <CL/sycl/sampler.hpp>
52+
#include <CL/sycl/specialization_id.hpp>
5153
#include <CL/sycl/stream.hpp>
5254
#include <CL/sycl/types.hpp>
5355
#include <CL/sycl/usm.hpp>

sycl/include/CL/sycl/detail/cg_types.hpp

Lines changed: 120 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <CL/sycl/interop_handle.hpp>
1616
#include <CL/sycl/interop_handler.hpp>
1717
#include <CL/sycl/kernel.hpp>
18+
#include <CL/sycl/kernel_handler.hpp>
1819
#include <CL/sycl/nd_item.hpp>
1920
#include <CL/sycl/range.hpp>
2021

@@ -122,6 +123,97 @@ class NDRDescT {
122123
size_t Dims;
123124
};
124125

126+
template <typename, typename T> struct check_fn_signature {
127+
static_assert(std::integral_constant<T, false>::value,
128+
"Second template parameter is required to be of function type");
129+
};
130+
131+
template <typename F, typename RetT, typename... Args>
132+
struct check_fn_signature<F, RetT(Args...)> {
133+
private:
134+
template <typename T>
135+
static constexpr auto check(T *) -> typename std::is_same<
136+
decltype(std::declval<T>().operator()(std::declval<Args>()...)),
137+
RetT>::type;
138+
139+
template <typename> static constexpr std::false_type check(...);
140+
141+
using type = decltype(check<F>(0));
142+
143+
public:
144+
static constexpr bool value = type::value;
145+
};
146+
147+
template <typename F, typename... Args>
148+
static constexpr bool check_kernel_lambda_takes_args() {
149+
return check_fn_signature<std::remove_reference_t<F>, void(Args...)>::value;
150+
}
151+
152+
// isKernelLambdaCallableWithKernelHandlerImpl checks if LambdaArgType is void
153+
// (e.g., in single_task), and based on that, calls
154+
// check_kernel_lambda_takes_args with proper set of arguments. Also this type
155+
// trait workarounds compilation error which happens only with msvc.
156+
157+
template <typename KernelType, typename LambdaArgType,
158+
typename std::enable_if_t<std::is_same<LambdaArgType, void>::value>
159+
* = nullptr>
160+
constexpr bool isKernelLambdaCallableWithKernelHandlerImpl() {
161+
return check_kernel_lambda_takes_args<KernelType, kernel_handler>();
162+
}
163+
164+
template <typename KernelType, typename LambdaArgType,
165+
typename std::enable_if_t<!std::is_same<LambdaArgType, void>::value>
166+
* = nullptr>
167+
constexpr bool isKernelLambdaCallableWithKernelHandlerImpl() {
168+
return check_kernel_lambda_takes_args<KernelType, LambdaArgType,
169+
kernel_handler>();
170+
}
171+
172+
// Type traits to find out if kernal lambda has kernel_handler argument
173+
174+
template <typename KernelType>
175+
constexpr bool isKernelLambdaCallableWithKernelHandler() {
176+
return check_kernel_lambda_takes_args<KernelType, kernel_handler>();
177+
}
178+
179+
template <typename KernelType, typename LambdaArgType>
180+
constexpr bool isKernelLambdaCallableWithKernelHandler() {
181+
return isKernelLambdaCallableWithKernelHandlerImpl<KernelType,
182+
LambdaArgType>();
183+
}
184+
185+
// Helpers for running kernel lambda on the host device
186+
187+
template <typename KernelType,
188+
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
189+
KernelType>()> * = nullptr>
190+
constexpr void runKernelWithoutArg(KernelType KernelName) {
191+
kernel_handler KH;
192+
KernelName(KH);
193+
}
194+
195+
template <typename KernelType,
196+
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
197+
KernelType>()> * = nullptr>
198+
constexpr void runKernelWithoutArg(KernelType KernelName) {
199+
KernelName();
200+
}
201+
202+
template <typename ArgType, typename KernelType,
203+
typename std::enable_if_t<isKernelLambdaCallableWithKernelHandler<
204+
KernelType, ArgType>()> * = nullptr>
205+
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
206+
kernel_handler KH;
207+
KernelName(Arg, KH);
208+
}
209+
210+
template <typename ArgType, typename KernelType,
211+
typename std::enable_if_t<!isKernelLambdaCallableWithKernelHandler<
212+
KernelType, ArgType>()> * = nullptr>
213+
constexpr void runKernelWithArg(KernelType KernelName, ArgType Arg) {
214+
KernelName(Arg);
215+
}
216+
125217
// The pure virtual class aimed to store lambda/functors of any type.
126218
class HostKernelBase {
127219
public:
@@ -197,7 +289,7 @@ class HostKernel : public HostKernelBase {
197289
template <class ArgT = KernelArgType>
198290
typename detail::enable_if_t<std::is_same<ArgT, void>::value>
199291
runOnHost(const NDRDescT &) {
200-
MKernel();
292+
runKernelWithoutArg(MKernel);
201293
}
202294

203295
template <class ArgT = KernelArgType>
@@ -218,18 +310,18 @@ class HostKernel : public HostKernelBase {
218310
UpperBound[I] = Range[I] + Offset[I];
219311
}
220312

221-
detail::NDLoop<Dims>::iterate(/*LowerBound=*/Offset, Stride, UpperBound,
222-
[&](const sycl::id<Dims> &ID) {
223-
sycl::item<Dims, /*Offset=*/true> Item =
224-
IDBuilder::createItem<Dims, true>(
225-
Range, ID, Offset);
226-
227-
if (StoreLocation) {
228-
store_id(&ID);
229-
store_item(&Item);
230-
}
231-
MKernel(ID);
232-
});
313+
detail::NDLoop<Dims>::iterate(
314+
/*LowerBound=*/Offset, Stride, UpperBound,
315+
[&](const sycl::id<Dims> &ID) {
316+
sycl::item<Dims, /*Offset=*/true> Item =
317+
IDBuilder::createItem<Dims, true>(Range, ID, Offset);
318+
319+
if (StoreLocation) {
320+
store_id(&ID);
321+
store_item(&Item);
322+
}
323+
runKernelWithArg<const sycl::id<Dims> &>(MKernel, ID);
324+
});
233325
}
234326

235327
template <class ArgT = KernelArgType>
@@ -253,7 +345,7 @@ class HostKernel : public HostKernelBase {
253345
store_id(&ID);
254346
store_item(&ItemWithOffset);
255347
}
256-
MKernel(Item);
348+
runKernelWithArg<sycl::item<Dims, /*Offset=*/false>>(MKernel, Item);
257349
});
258350
}
259351

@@ -276,18 +368,18 @@ class HostKernel : public HostKernelBase {
276368
UpperBound[I] = Range[I] + Offset[I];
277369
}
278370

279-
detail::NDLoop<Dims>::iterate(/*LowerBound=*/Offset, Stride, UpperBound,
280-
[&](const sycl::id<Dims> &ID) {
281-
sycl::item<Dims, /*Offset=*/true> Item =
282-
IDBuilder::createItem<Dims, true>(
283-
Range, ID, Offset);
284-
285-
if (StoreLocation) {
286-
store_id(&ID);
287-
store_item(&Item);
288-
}
289-
MKernel(Item);
290-
});
371+
detail::NDLoop<Dims>::iterate(
372+
/*LowerBound=*/Offset, Stride, UpperBound,
373+
[&](const sycl::id<Dims> &ID) {
374+
sycl::item<Dims, /*Offset=*/true> Item =
375+
IDBuilder::createItem<Dims, true>(Range, ID, Offset);
376+
377+
if (StoreLocation) {
378+
store_id(&ID);
379+
store_item(&Item);
380+
}
381+
runKernelWithArg<sycl::item<Dims, /*Offset=*/true>>(MKernel, Item);
382+
});
291383
}
292384

293385
template <class ArgT = KernelArgType>
@@ -336,7 +428,7 @@ class HostKernel : public HostKernelBase {
336428
auto g = NDItem.get_group();
337429
store_group(&g);
338430
}
339-
MKernel(NDItem);
431+
runKernelWithArg<const sycl::nd_item<Dims>>(MKernel, NDItem);
340432
});
341433
});
342434
}
@@ -364,7 +456,7 @@ class HostKernel : public HostKernelBase {
364456
detail::NDLoop<Dims>::iterate(NGroups, [&](const id<Dims> &GroupID) {
365457
sycl::group<Dims> Group =
366458
IDBuilder::createGroup<Dims>(GlobalSize, LocalSize, NGroups, GroupID);
367-
MKernel(Group);
459+
runKernelWithArg<sycl::group<Dims>>(MKernel, Group);
368460
});
369461
}
370462

sycl/include/CL/sycl/detail/kernel_desc.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ enum class kernel_param_kind_t {
2626
kind_accessor = 0,
2727
kind_std_layout = 1, // standard layout object parameters
2828
kind_sampler = 2,
29-
kind_pointer = 3
29+
kind_pointer = 3,
30+
kind_specialization_constants_buffer = 4,
3031
};
3132

3233
// describes a kernel parameter

0 commit comments

Comments
 (0)