15
15
#include < CL/sycl/interop_handle.hpp>
16
16
#include < CL/sycl/interop_handler.hpp>
17
17
#include < CL/sycl/kernel.hpp>
18
+ #include < CL/sycl/kernel_handler.hpp>
18
19
#include < CL/sycl/nd_item.hpp>
19
20
#include < CL/sycl/range.hpp>
20
21
@@ -122,6 +123,97 @@ class NDRDescT {
122
123
size_t Dims;
123
124
};
124
125
126
/// Compile-time check that a callable type has a member call operator which
/// accepts the arguments of the function type given as the second template
/// parameter and returns exactly that function type's return type.
///
/// The primary template is selected only when the second template parameter
/// is not of the form RetT(Args...); instantiating it is always an error.
template <typename, typename T> struct check_fn_signature {
  // The condition depends on T, so it is evaluated only on instantiation.
  // Unlike std::integral_constant<T, false>, it is well-formed for every T
  // (void, class types, ...), so the message below is what the user sees.
  static_assert(!std::is_same<T, T>::value,
                " Second template parameter is required to be of function type");
};

/// Specialization for a function type RetT(Args...).
///
/// `value` is true when the expression
///   std::declval<F>().operator()(std::declval<Args>()...)
/// is well-formed and its type is exactly RetT; otherwise it is false.
template <typename F, typename RetT, typename... Args>
struct check_fn_signature<F, RetT(Args...)> {
private:
  // Preferred overload: viable only if T has a member operator() callable
  // with Args. Its return type is std::true_type or std::false_type depending
  // on whether the call yields exactly RetT. Declared only; it is used solely
  // inside decltype.
  template <typename T>
  static constexpr auto check(T *) -> typename std::is_same<
      decltype(std::declval<T>().operator()(std::declval<Args>()...)),
      RetT>::type;

  // Fallback picked when substitution into the overload above fails.
  template <typename> static constexpr std::false_type check(...);

  using type = decltype(check<F>(nullptr));

public:
  static constexpr bool value = type::value;
};
146
+
147
+ template <typename F, typename ... Args>
148
+ static constexpr bool check_kernel_lambda_takes_args () {
149
+ return check_fn_signature<std::remove_reference_t <F>, void (Args...)>::value;
150
+ }
151
+
152
+ // isKernelLambdaCallableWithKernelHandlerImpl checks if LambdaArgType is void
153
+ // (e.g., in single_task), and based on that, calls
154
+ // check_kernel_lambda_takes_args with proper set of arguments. Also this type
155
+ // trait works around a compilation error which happens only with msvc.
156
+
157
+ template <typename KernelType, typename LambdaArgType,
158
+ typename std::enable_if_t <std::is_same<LambdaArgType, void >::value>
159
+ * = nullptr >
160
+ constexpr bool isKernelLambdaCallableWithKernelHandlerImpl () {
161
+ return check_kernel_lambda_takes_args<KernelType, kernel_handler>();
162
+ }
163
+
164
+ template <typename KernelType, typename LambdaArgType,
165
+ typename std::enable_if_t <!std::is_same<LambdaArgType, void >::value>
166
+ * = nullptr >
167
+ constexpr bool isKernelLambdaCallableWithKernelHandlerImpl () {
168
+ return check_kernel_lambda_takes_args<KernelType, LambdaArgType,
169
+ kernel_handler>();
170
+ }
171
+
172
+ // Type traits to find out if kernel lambda has kernel_handler argument
173
+
174
+ template <typename KernelType>
175
+ constexpr bool isKernelLambdaCallableWithKernelHandler () {
176
+ return check_kernel_lambda_takes_args<KernelType, kernel_handler>();
177
+ }
178
+
179
+ template <typename KernelType, typename LambdaArgType>
180
+ constexpr bool isKernelLambdaCallableWithKernelHandler () {
181
+ return isKernelLambdaCallableWithKernelHandlerImpl<KernelType,
182
+ LambdaArgType>();
183
+ }
184
+
185
+ // Helpers for running kernel lambda on the host device
186
+
187
+ template <typename KernelType,
188
+ typename std::enable_if_t <isKernelLambdaCallableWithKernelHandler<
189
+ KernelType>()> * = nullptr >
190
+ constexpr void runKernelWithoutArg (KernelType KernelName) {
191
+ kernel_handler KH;
192
+ KernelName (KH);
193
+ }
194
+
195
+ template <typename KernelType,
196
+ typename std::enable_if_t <!isKernelLambdaCallableWithKernelHandler<
197
+ KernelType>()> * = nullptr >
198
+ constexpr void runKernelWithoutArg (KernelType KernelName) {
199
+ KernelName ();
200
+ }
201
+
202
+ template <typename ArgType, typename KernelType,
203
+ typename std::enable_if_t <isKernelLambdaCallableWithKernelHandler<
204
+ KernelType, ArgType>()> * = nullptr >
205
+ constexpr void runKernelWithArg (KernelType KernelName, ArgType Arg) {
206
+ kernel_handler KH;
207
+ KernelName (Arg, KH);
208
+ }
209
+
210
+ template <typename ArgType, typename KernelType,
211
+ typename std::enable_if_t <!isKernelLambdaCallableWithKernelHandler<
212
+ KernelType, ArgType>()> * = nullptr >
213
+ constexpr void runKernelWithArg (KernelType KernelName, ArgType Arg) {
214
+ KernelName (Arg);
215
+ }
216
+
125
217
// The pure virtual class aimed to store lambda/functors of any type.
126
218
class HostKernelBase {
127
219
public:
@@ -197,7 +289,7 @@ class HostKernel : public HostKernelBase {
197
289
template <class ArgT = KernelArgType>
198
290
typename detail::enable_if_t <std::is_same<ArgT, void >::value>
199
291
runOnHost (const NDRDescT &) {
200
- MKernel ( );
292
+ runKernelWithoutArg (MKernel );
201
293
}
202
294
203
295
template <class ArgT = KernelArgType>
@@ -218,18 +310,18 @@ class HostKernel : public HostKernelBase {
218
310
UpperBound[I] = Range[I] + Offset[I];
219
311
}
220
312
221
- detail::NDLoop<Dims>::iterate(/* LowerBound= */ Offset, Stride, UpperBound,
222
- [&]( const sycl::id<Dims> &ID) {
223
- sycl::item <Dims, /* Offset= */ true > Item =
224
- IDBuilder::createItem <Dims, true >(
225
- Range, ID, Offset);
226
-
227
- if (StoreLocation) {
228
- store_id (&ID);
229
- store_item (&Item);
230
- }
231
- MKernel ( ID);
232
- });
313
+ detail::NDLoop<Dims>::iterate(
314
+ /* LowerBound= */ Offset, Stride, UpperBound,
315
+ [&]( const sycl::id <Dims> &ID) {
316
+ sycl::item <Dims, /* Offset= */ true > Item =
317
+ IDBuilder::createItem<Dims, true >( Range, ID, Offset);
318
+
319
+ if (StoreLocation) {
320
+ store_id (&ID);
321
+ store_item (&Item);
322
+ }
323
+ runKernelWithArg< const sycl::id<Dims> &>(MKernel, ID);
324
+ });
233
325
}
234
326
235
327
template <class ArgT = KernelArgType>
@@ -253,7 +345,7 @@ class HostKernel : public HostKernelBase {
253
345
store_id (&ID);
254
346
store_item (&ItemWithOffset);
255
347
}
256
- MKernel ( Item);
348
+ runKernelWithArg<sycl::item<Dims, /* Offset= */ false >>(MKernel, Item);
257
349
});
258
350
}
259
351
@@ -276,18 +368,18 @@ class HostKernel : public HostKernelBase {
276
368
UpperBound[I] = Range[I] + Offset[I];
277
369
}
278
370
279
- detail::NDLoop<Dims>::iterate(/* LowerBound= */ Offset, Stride, UpperBound,
280
- [&]( const sycl::id<Dims> &ID) {
281
- sycl::item <Dims, /* Offset= */ true > Item =
282
- IDBuilder::createItem <Dims, true >(
283
- Range, ID, Offset);
284
-
285
- if (StoreLocation) {
286
- store_id (&ID);
287
- store_item (&Item);
288
- }
289
- MKernel ( Item);
290
- });
371
+ detail::NDLoop<Dims>::iterate(
372
+ /* LowerBound= */ Offset, Stride, UpperBound,
373
+ [&]( const sycl::id <Dims> &ID) {
374
+ sycl::item <Dims, /* Offset= */ true > Item =
375
+ IDBuilder::createItem<Dims, true >( Range, ID, Offset);
376
+
377
+ if (StoreLocation) {
378
+ store_id (&ID);
379
+ store_item (&Item);
380
+ }
381
+ runKernelWithArg<sycl::item<Dims, /* Offset= */ true >>(MKernel, Item);
382
+ });
291
383
}
292
384
293
385
template <class ArgT = KernelArgType>
@@ -336,7 +428,7 @@ class HostKernel : public HostKernelBase {
336
428
auto g = NDItem.get_group ();
337
429
store_group (&g);
338
430
}
339
- MKernel ( NDItem);
431
+ runKernelWithArg< const sycl::nd_item<Dims>>(MKernel, NDItem);
340
432
});
341
433
});
342
434
}
@@ -364,7 +456,7 @@ class HostKernel : public HostKernelBase {
364
456
detail::NDLoop<Dims>::iterate(NGroups, [&](const id<Dims> &GroupID) {
365
457
sycl::group<Dims> Group =
366
458
IDBuilder::createGroup<Dims>(GlobalSize, LocalSize, NGroups, GroupID);
367
- MKernel ( Group);
459
+ runKernelWithArg<sycl::group<Dims>>(MKernel, Group);
368
460
});
369
461
}
370
462
0 commit comments