Skip to content

Commit 25b482b

Browse files
authored
[SYCL] Propagate attributes of original kernel to wrapper kernel generated for range-rounding (#3306)
This change propagates attributes of a user-written SYCL kernel to the kernel generated as a wrapper around the original kernel. The wrapped kernel is executed with the original range rounded-up, which improves work group formation on CPUs and GPUs. Signed-off-by: rdeodhar rajiv.deodhar@intel.com
1 parent 505dccb commit 25b482b

19 files changed

+294
-136
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3195,6 +3195,9 @@ def warn_dllimport_dropped_from_inline_function : Warning<
31953195
InGroup<IgnoredAttributes>;
31963196
def warn_attribute_ignored : Warning<"%0 attribute ignored">,
31973197
InGroup<IgnoredAttributes>;
3198+
def warn_attribute_on_direct_kernel_callee_only : Warning<"%0 attribute allowed"
3199+
" only on a function directly called from a SYCL kernel function; attribute ignored">,
3200+
InGroup<IgnoredAttributes>;
31983201
def warn_nothrow_attribute_ignored : Warning<"'nothrow' attribute conflicts with"
31993202
" exception specification; attribute ignored">,
32003203
InGroup<IgnoredAttributes>;

clang/include/clang/Sema/Sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13023,6 +13023,7 @@ class Sema final {
1302313023

1302413024
bool isKnownGoodSYCLDecl(const Decl *D);
1302513025
void checkSYCLDeviceVarDecl(VarDecl *Var);
13026+
void copySYCLKernelAttrs(const CXXRecordDecl *KernelObj);
1302613027
void ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC);
1302713028
void MarkDevice();
1302813029

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 101 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,37 @@ static int64_t getIntExprValue(const Expr *E, ASTContext &Ctx) {
306306
return E->getIntegerConstantExpr(Ctx)->getSExtValue();
307307
}
308308

309+
// Collect function attributes related to SYCL.
310+
static void collectSYCLAttributes(Sema &S, FunctionDecl *FD,
311+
llvm::SmallVector<Attr *, 4> &Attrs,
312+
bool DirectlyCalled = true) {
313+
if (!FD->hasAttrs())
314+
return;
315+
316+
llvm::copy_if(FD->getAttrs(), std::back_inserter(Attrs), [](Attr *A) {
317+
// FIXME: Make this list self-adapt as new SYCL attributes are added.
318+
return isa<IntelReqdSubGroupSizeAttr, ReqdWorkGroupSizeAttr,
319+
SYCLIntelKernelArgsRestrictAttr, SYCLIntelNumSimdWorkItemsAttr,
320+
SYCLIntelSchedulerTargetFmaxMhzAttr,
321+
SYCLIntelMaxWorkGroupSizeAttr, SYCLIntelMaxGlobalWorkDimAttr,
322+
SYCLIntelNoGlobalWorkOffsetAttr, SYCLSimdAttr>(A);
323+
});
324+
325+
// Allow the kernel attribute "use_stall_enable_clusters" only on lambda
326+
// functions and function objects called directly from a kernel.
327+
// For all other cases, emit a warning and ignore.
328+
if (auto *A = FD->getAttr<SYCLIntelUseStallEnableClustersAttr>()) {
329+
if (DirectlyCalled) {
330+
Attrs.push_back(A);
331+
} else {
332+
S.Diag(A->getLocation(),
333+
diag::warn_attribute_on_direct_kernel_callee_only)
334+
<< A;
335+
FD->dropAttr<SYCLIntelUseStallEnableClustersAttr>();
336+
}
337+
}
338+
}
339+
309340
class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
310341
// Used to keep track of the constexpr depth, so we know whether to skip
311342
// diagnostics.
@@ -477,7 +508,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
477508
// Returns the kernel body function found during traversal.
478509
FunctionDecl *
479510
CollectPossibleKernelAttributes(FunctionDecl *SYCLKernel,
480-
llvm::SmallPtrSet<Attr *, 4> &Attrs) {
511+
llvm::SmallVector<Attr *, 4> &Attrs) {
481512
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;
482513
llvm::SmallPtrSet<FunctionDecl *, 16> Visited;
483514
llvm::SmallVector<ChildParentPair, 16> WorkList;
@@ -508,55 +539,23 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
508539
"function can be called");
509540
KernelBody = FD;
510541
}
542+
511543
WorkList.pop_back();
512544
if (!Visited.insert(FD).second)
513545
continue; // We've already seen this Decl
514546

515-
if (auto *A = FD->getAttr<IntelReqdSubGroupSizeAttr>())
516-
Attrs.insert(A);
517-
518-
if (auto *A = FD->getAttr<ReqdWorkGroupSizeAttr>())
519-
Attrs.insert(A);
520-
521-
if (auto *A = FD->getAttr<SYCLIntelKernelArgsRestrictAttr>())
522-
Attrs.insert(A);
523-
524-
if (auto *A = FD->getAttr<SYCLIntelNumSimdWorkItemsAttr>())
525-
Attrs.insert(A);
526-
527-
if (auto *A = FD->getAttr<SYCLIntelSchedulerTargetFmaxMhzAttr>())
528-
Attrs.insert(A);
529-
530-
if (auto *A = FD->getAttr<SYCLIntelMaxWorkGroupSizeAttr>())
531-
Attrs.insert(A);
532-
533-
if (auto *A = FD->getAttr<SYCLIntelMaxGlobalWorkDimAttr>())
534-
Attrs.insert(A);
535-
536-
if (auto *A = FD->getAttr<SYCLIntelNoGlobalWorkOffsetAttr>())
537-
Attrs.insert(A);
538-
539-
if (auto *A = FD->getAttr<SYCLSimdAttr>())
540-
Attrs.insert(A);
541-
542-
// Allow the kernel attribute "use_stall_enable_clusters" only on lambda
543-
// functions and function objects that are called directly from a kernel
544-
// (i.e. the one passed to the single_task or parallel_for functions).
545-
// For all other cases, emit a warning and ignore.
546-
if (auto *A = FD->getAttr<SYCLIntelUseStallEnableClustersAttr>()) {
547-
if (ParentFD == SYCLKernel) {
548-
Attrs.insert(A);
549-
} else {
550-
SemaRef.Diag(A->getLocation(), diag::warn_attribute_ignored) << A;
551-
FD->dropAttr<SYCLIntelUseStallEnableClustersAttr>();
552-
}
553-
}
547+
// Gather all attributes of FD that are SYCL related.
548+
// Some attributes are allowed only on lambda functions and function
549+
// objects called directly from a kernel (i.e. the one passed to the
550+
// single_task or parallel_for functions).
551+
bool DirectlyCalled = (ParentFD == SYCLKernel);
552+
collectSYCLAttributes(SemaRef, FD, Attrs, DirectlyCalled);
554553

555554
// Attribute "loop_fuse" can be applied explicitly on kernel function.
556555
// Attribute should not be propagated from device functions to kernel.
557556
if (auto *A = FD->getAttr<SYCLIntelLoopFuseAttr>()) {
558557
if (ParentFD == SYCLKernel) {
559-
Attrs.insert(A);
558+
Attrs.push_back(A);
560559
}
561560
}
562561

@@ -2058,8 +2057,8 @@ class SyclKernelArgsSizeChecker : public SyclKernelFieldHandler {
20582057
using SyclKernelFieldHandler::handleSyclHalfType;
20592058
};
20602059

2061-
static const CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
2062-
for (const auto *MD : Rec->methods()) {
2060+
static CXXMethodDecl *getOperatorParens(const CXXRecordDecl *Rec) {
2061+
for (auto *MD : Rec->methods()) {
20632062
if (MD->getOverloadedOperator() == OO_Call)
20642063
return MD;
20652064
}
@@ -3149,6 +3148,56 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
31493148
KernelFunc->setInvalidDecl();
31503149
}
31513150

3151+
// For a wrapped parallel_for, copy attributes from original
3152+
// kernel to wrapped kernel.
3153+
void Sema::copySYCLKernelAttrs(const CXXRecordDecl *KernelObj) {
3154+
// Get the operator() function of the wrapper.
3155+
CXXMethodDecl *OpParens = getOperatorParens(KernelObj);
3156+
assert(OpParens && "invalid kernel object");
3157+
3158+
typedef std::pair<FunctionDecl *, FunctionDecl *> ChildParentPair;
3159+
llvm::SmallPtrSet<FunctionDecl *, 16> Visited;
3160+
llvm::SmallVector<ChildParentPair, 16> WorkList;
3161+
WorkList.push_back({OpParens, nullptr});
3162+
FunctionDecl *KernelBody = nullptr;
3163+
3164+
CallGraph SYCLCG;
3165+
SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
3166+
while (!WorkList.empty()) {
3167+
FunctionDecl *FD = WorkList.back().first;
3168+
FunctionDecl *ParentFD = WorkList.back().second;
3169+
3170+
if ((ParentFD == OpParens) && isSYCLKernelBodyFunction(FD)) {
3171+
KernelBody = FD;
3172+
break;
3173+
}
3174+
3175+
WorkList.pop_back();
3176+
if (!Visited.insert(FD).second)
3177+
continue; // We've already seen this Decl
3178+
3179+
CallGraphNode *N = SYCLCG.getNode(FD);
3180+
if (!N)
3181+
continue;
3182+
3183+
for (const CallGraphNode *CI : *N) {
3184+
if (auto *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
3185+
Callee = Callee->getMostRecentDecl();
3186+
if (!Visited.count(Callee))
3187+
WorkList.push_back({Callee, FD});
3188+
}
3189+
}
3190+
}
3191+
3192+
assert(KernelBody && "improper parallel_for wrap");
3193+
if (KernelBody) {
3194+
llvm::SmallVector<Attr *, 4> Attrs;
3195+
collectSYCLAttributes(*this, KernelBody, Attrs);
3196+
if (!Attrs.empty())
3197+
llvm::for_each(Attrs, [OpParens](Attr *A) { OpParens->addAttr(A); });
3198+
}
3199+
}
3200+
31523201
// Generates the OpenCL kernel using KernelCallerFunc (kernel caller
31533202
// function) defined in SYCL headers.
31543203
// Generated OpenCL kernel contains the body of the kernel caller function,
@@ -3181,14 +3230,20 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
31813230
if (KernelObj->isInvalidDecl())
31823231
return;
31833232

3184-
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
3185-
31863233
// Calculate both names, since Integration headers need both.
31873234
std::string CalculatedName, StableName;
31883235
std::tie(CalculatedName, StableName) =
31893236
constructKernelName(*this, KernelCallerFunc, MC);
31903237
StringRef KernelName(getLangOpts().SYCLUnnamedLambda ? StableName
31913238
: CalculatedName);
3239+
3240+
// Attributes of a user-written SYCL kernel must be copied to the internally
3241+
// generated alternative kernel, identified by a known string in its name.
3242+
if (StableName.find("__pf_kernel_wrapper") != std::string::npos)
3243+
copySYCLKernelAttrs(KernelObj);
3244+
3245+
bool IsSIMDKernel = isESIMDKernelType(KernelObj);
3246+
31923247
SyclKernelDeclCreator kernel_decl(*this, KernelName, KernelObj->getLocation(),
31933248
KernelCallerFunc->isInlined(),
31943249
IsSIMDKernel);
@@ -3226,7 +3281,7 @@ void Sema::MarkDevice(void) {
32263281
Marker.CollectKernelSet(SYCLKernel, SYCLKernel, VisitedSet);
32273282

32283283
// Let's propagate attributes from device functions to SYCL kernels
3229-
llvm::SmallPtrSet<Attr *, 4> Attrs;
3284+
llvm::SmallVector<Attr *, 4> Attrs;
32303285
// This function collects all kernel attributes which might be applied to
32313286
// device functions, but need to be propagated down to callers, i.e.
32323287
// SYCL kernels

clang/test/SemaSYCL/Inputs/sycl.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,24 @@ template <typename Type>
206206
struct get_kernel_name_t<auto_name, Type> {
207207
using name = Type;
208208
};
209+
210+
// Used when parallel_for range is rounded-up.
211+
template <typename Type> class __pf_kernel_wrapper;
212+
213+
template <typename Type> struct get_kernel_wrapper_name_t {
214+
using name =
215+
__pf_kernel_wrapper<typename get_kernel_name_t<auto_name, Type>::name>;
216+
};
217+
209218
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
210219
template <typename KernelName = auto_name, typename KernelType>
211220
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) {
212221
kernelFunc();
213222
}
223+
template <typename KernelName = auto_name, typename KernelType>
224+
ATTR_SYCL_KERNEL void kernel_parallel_for(const KernelType &kernelFunc) {
225+
kernelFunc();
226+
}
214227
class handler {
215228
public:
216229
template <typename KernelName = auto_name, typename KernelType>
@@ -220,6 +233,16 @@ class handler {
220233
kernel_single_task<NameT>(kernelFunc);
221234
#else
222235
kernelFunc();
236+
#endif
237+
}
238+
template <typename KernelName = auto_name, typename KernelType>
239+
void parallel_for(const KernelType &kernelObj) {
240+
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
241+
using NameWT = typename get_kernel_wrapper_name_t<NameT>::name;
242+
#ifdef __SYCL_DEVICE_ONLY__
243+
kernel_parallel_for<NameT>(kernelObj);
244+
#else
245+
kernelObj();
223246
#endif
224247
}
225248
};

clang/test/SemaSYCL/args-size-overflow.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ queue q;
1111
using Accessor =
1212
accessor<int, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer>;
1313
#ifdef SPIR64
14-
// expected-warning@Inputs/sycl.hpp:220 {{size of kernel arguments (7994 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
14+
// expected-warning@Inputs/sycl.hpp:233 {{size of kernel arguments (7994 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
1515
#elif SPIR32
16-
// expected-warning@Inputs/sycl.hpp:220 {{size of kernel arguments (7986 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
16+
// expected-warning@Inputs/sycl.hpp:233 {{size of kernel arguments (7986 bytes) may exceed the supported maximum of 2048 bytes on some devices}}
1717
#endif
1818

1919
void use() {

clang/test/SemaSYCL/deferred-diagnostics-aux-builtin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ int main(int argc, char **argv) {
1212
_mm_prefetch("test", 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}}
1313

1414
deviceQueue.submit([&](sycl::handler &h) {
15-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<AName, (lambda}}
15+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<AName, (lambda}}
1616
h.single_task<class AName>([]() {
1717
_mm_prefetch("test", 4); // expected-error {{builtin is not supported on this target}}
1818
_mm_prefetch("test", 8); // expected-error {{argument value 8 is outside the valid range [0, 7]}} expected-error {{builtin is not supported on this target}}

clang/test/SemaSYCL/deferred-diagnostics-emit.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ template <typename T>
6464
void setup_sycl_operation(const T VA[]) {
6565

6666
deviceQueue.submit([&](sycl::handler &h) {
67-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<AName, (lambda}}
67+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<AName, (lambda}}
6868
h.single_task<class AName>([]() {
6969
// ======= Zero Length Arrays Not Allowed in Kernel ==========
7070
// expected-error@+1 {{zero-length arrays are not permitted in C++}}
@@ -156,7 +156,7 @@ int main(int argc, char **argv) {
156156

157157
// --- direct lambda testing ---
158158
deviceQueue.submit([&](sycl::handler &h) {
159-
// expected-note@Inputs/sycl.hpp:212 2 {{called by 'kernel_single_task<AName, (lambda}}
159+
// expected-note@Inputs/sycl.hpp:221 2 {{called by 'kernel_single_task<AName, (lambda}}
160160
h.single_task<class AName>([]() {
161161
// expected-error@+1 {{zero-length arrays are not permitted in C++}}
162162
int BadArray[0];

clang/test/SemaSYCL/float128.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ int main() {
7171
__float128 CapturedToDevice = 1;
7272
host_ok();
7373
deviceQueue.submit([&](sycl::handler &h) {
74-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<variables, (lambda}}
74+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<variables, (lambda}}
7575
h.single_task<class variables>([=]() {
7676
// expected-error@+1 {{'__float128' is not supported on this target}}
7777
decltype(CapturedToDevice) D;
@@ -88,7 +88,7 @@ int main() {
8888
});
8989

9090
deviceQueue.submit([&](sycl::handler &h) {
91-
// expected-note@Inputs/sycl.hpp:212 4{{called by 'kernel_single_task<functions, (lambda}}
91+
// expected-note@Inputs/sycl.hpp:221 4{{called by 'kernel_single_task<functions, (lambda}}
9292
h.single_task<class functions>([=]() {
9393
// expected-note@+1 2{{called by 'operator()'}}
9494
usage();
@@ -104,7 +104,7 @@ int main() {
104104
});
105105

106106
deviceQueue.submit([&](sycl::handler &h) {
107-
// expected-note@Inputs/sycl.hpp:212 {{called by 'kernel_single_task<ok, (lambda}}
107+
// expected-note@Inputs/sycl.hpp:221 {{called by 'kernel_single_task<ok, (lambda}}
108108
h.single_task<class ok>([=]() {
109109
// expected-note@+1 3{{used here}}
110110
Z<__float128> S;

clang/test/SemaSYCL/implicit_kernel_type.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ int main() {
2525
queue q;
2626

2727
#if defined(WARN)
28-
// expected-error@Inputs/sycl.hpp:220 {{'InvalidKernelName1' is an invalid kernel name type}}
29-
// expected-note@Inputs/sycl.hpp:220 {{'InvalidKernelName1' should be globally-visible}}
28+
// expected-error@Inputs/sycl.hpp:233 {{'InvalidKernelName1' is an invalid kernel name type}}
29+
// expected-note@Inputs/sycl.hpp:233 {{'InvalidKernelName1' should be globally-visible}}
3030
// expected-note@+8 {{in instantiation of function template specialization}}
3131
#elif defined(ERROR)
32-
// expected-error@Inputs/sycl.hpp:220 {{'InvalidKernelName1' is an invalid kernel name type}}
33-
// expected-note@Inputs/sycl.hpp:220 {{'InvalidKernelName1' should be globally-visible}}
32+
// expected-error@Inputs/sycl.hpp:233 {{'InvalidKernelName1' is an invalid kernel name type}}
33+
// expected-note@Inputs/sycl.hpp:233 {{'InvalidKernelName1' should be globally-visible}}
3434
// expected-note@+4 {{in instantiation of function template specialization}}
3535
#endif
3636
class InvalidKernelName1 {};
@@ -39,9 +39,9 @@ int main() {
3939
});
4040

4141
#if defined(WARN)
42-
// expected-warning@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
42+
// expected-warning@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
4343
#elif defined(ERROR)
44-
// expected-error@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
44+
// expected-error@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
4545
#endif
4646

4747
q.submit([&](handler &h) {
@@ -53,9 +53,9 @@ int main() {
5353
});
5454

5555
#if defined(WARN)
56-
// expected-warning@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
56+
// expected-warning@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
5757
#elif defined(ERROR)
58-
// expected-error@Inputs/sycl.hpp:220 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
58+
// expected-error@Inputs/sycl.hpp:233 {{SYCL 1.2.1 specification requires an explicit forward declaration for a kernel type name; your program may not be portable}}
5959
#endif
6060

6161
q.submit([&](handler &h) {

0 commit comments

Comments
 (0)