Skip to content

Commit 8e40116

Browse files
[SYCL] Implement SYCL 2020 specialization constants in Clang (#3345)
This patch implements SYCL 2020 specialization constants in the frontend. Specialization constants in SYCL 2020 are not captured by lambda and are accessed through new optional lambda argument kernel_handler. If this argument is present, a clone of the kernel_handler object should be constructed in OpenCL kernel. If the target has no native support for specialization constants, the compiler generates an OpenCL kernel argument specialization_constants_buffer of type char*. Generated kernel_handler clone should then be initialized using the generated kernel argument through __init_specialization_constants_buffer method. The generated kernel argument `specialization_constants_buffer` should also have a corresponding entry in the `kernel_signatures` structure in the integration header. The param kind for this argument should be `kernel_param_kind_t:specialization_constants_buffer` If the target has native support for specialization constants, the additional argument, and corresponding handling, need not be generated. In this case, kernel_handler local clone is default constructed. Instances of kernel_handler in sycl_kernel is then replaced to use the local clone. Signed-off-by: Elizabeth Andrews <elizabeth.andrews@intel.com>
1 parent 8b1e9ed commit 8e40116

17 files changed

+495
-82
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,8 @@ class SYCLIntegrationHeader {
315315
kind_std_layout,
316316
kind_sampler,
317317
kind_pointer,
318-
kind_last = kind_pointer
318+
kind_specialization_constants_buffer,
319+
kind_last = kind_specialization_constants_buffer
319320
};
320321

321322
public:

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 177 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ enum KernelInvocationKind {
5858

5959
static constexpr llvm::StringLiteral InitMethodName = "__init";
6060
static constexpr llvm::StringLiteral InitESIMDMethodName = "__init_esimd";
61+
static constexpr llvm::StringLiteral InitSpecConstantsBuffer =
62+
"__init_specialization_constants_buffer";
6163
static constexpr llvm::StringLiteral FinalizeMethodName = "__finalize";
6264
constexpr unsigned MaxKernelArgsSize = 2048;
6365

@@ -109,6 +111,10 @@ class Util {
109111
/// specialization constant class.
110112
static bool isSyclSpecConstantType(const QualType &Ty);
111113

114+
/// Checks whether given clang type is a full specialization of the SYCL
115+
/// kernel_handler class.
116+
static bool isSyclKernelHandlerType(const QualType &Ty);
117+
112118
// Checks declaration context hierarchy.
113119
/// \param DC the context of the item to be checked.
114120
/// \param Scopes the declaration scopes leading from the item context to the
@@ -616,11 +622,16 @@ class FindPFWGLambdaFnVisitor
616622
auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee());
617623
if (!M || (M->getOverloadedOperator() != OO_Call))
618624
return true;
619-
const int NumPFWGLambdaArgs = 2; // group and lambda obj
625+
626+
unsigned int NumPFWGLambdaArgs =
627+
M->getNumParams() + 1; // group, optional kernel_handler and lambda obj
620628
if (Call->getNumArgs() != NumPFWGLambdaArgs)
621629
return true;
622630
if (!Util::isSyclType(Call->getArg(1)->getType(), "group", true /*Tmpl*/))
623631
return true;
632+
if ((Call->getNumArgs() > 2) &&
633+
!Util::isSyclKernelHandlerType(Call->getArg(2)->getType()))
634+
return true;
624635
if (Call->getArg(0)->getType()->getAsCXXRecordDecl() != LambdaObjTy)
625636
return true;
626637
LambdaFn = M; // call to PFWG lambda found - record the lambda
@@ -732,12 +743,7 @@ static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) {
732743
Ctx.getTrivialTypeSourceInfo(Ty));
733744
}
734745

735-
static ParamDesc makeParamDesc(ASTContext &Ctx, const CXXBaseSpecifier &Src,
736-
QualType Ty) {
737-
// TODO: There is no name for the base available, but duplicate names are
738-
// seemingly already possible, so we'll give them all the same name for now.
739-
// This only happens with the accessor types.
740-
std::string Name = "_arg__base";
746+
static ParamDesc makeParamDesc(ASTContext &Ctx, StringRef Name, QualType Ty) {
741747
return std::make_tuple(Ty, &Ctx.Idents.get(Name),
742748
Ctx.getTrivialTypeSourceInfo(Ty));
743749
}
@@ -777,6 +783,28 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc,
777783
KernelNameType)};
778784
}
779785

786+
static bool isDefaultSPIRArch(ASTContext &Context) {
787+
llvm::Triple T = Context.getTargetInfo().getTriple();
788+
if (T.isSPIR() && T.getSubArch() == llvm::Triple::NoSubArch)
789+
return true;
790+
return false;
791+
}
792+
793+
static ParmVarDecl *getSyclKernelHandlerArg(FunctionDecl *KernelCallerFunc) {
794+
// Specialization constants in SYCL 2020 are not captured by lambda and
795+
// accessed through new optional lambda argument kernel_handler
796+
auto IsHandlerLambda = [](ParmVarDecl *PVD) {
797+
return Util::isSyclKernelHandlerType(PVD->getType());
798+
};
799+
800+
assert(llvm::count_if(KernelCallerFunc->parameters(), IsHandlerLambda) <= 1 &&
801+
"Multiple kernel_handler parameters");
802+
803+
auto KHArg = llvm::find_if(KernelCallerFunc->parameters(), IsHandlerLambda);
804+
805+
return (KHArg != KernelCallerFunc->param_end()) ? *KHArg : nullptr;
806+
}
807+
780808
// anonymous namespace so these don't get linkage.
781809
namespace {
782810

@@ -1642,10 +1670,20 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
16421670
}
16431671

16441672
void addParam(const CXXBaseSpecifier &BS, QualType FieldTy) {
1673+
// TODO: There is no name for the base available, but duplicate names are
1674+
// seemingly already possible, so we'll give them all the same name for now.
1675+
// This only happens with the accessor types.
1676+
StringRef Name = "_arg__base";
16451677
ParamDesc newParamDesc =
1646-
makeParamDesc(SemaRef.getASTContext(), BS, FieldTy);
1678+
makeParamDesc(SemaRef.getASTContext(), Name, FieldTy);
16471679
addParam(newParamDesc, FieldTy);
16481680
}
1681+
// Add a parameter with specified name and type
1682+
void addParam(StringRef Name, QualType ParamTy) {
1683+
ParamDesc newParamDesc =
1684+
makeParamDesc(SemaRef.getASTContext(), Name, ParamTy);
1685+
addParam(newParamDesc, ParamTy);
1686+
}
16491687

16501688
void addParam(ParamDesc newParamDesc, QualType FieldTy) {
16511689
// Create a new ParmVarDecl based on the new info.
@@ -1946,6 +1984,18 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
19461984
return true;
19471985
}
19481986

1987+
// Generate kernel argument to intialize specialization constants. This
1988+
// argument is only generated when the target has no native support for
1989+
// specialization constants
1990+
void handleSyclKernelHandlerType() {
1991+
ASTContext &Context = SemaRef.getASTContext();
1992+
if (isDefaultSPIRArch(Context))
1993+
return;
1994+
1995+
StringRef Name = "_arg__specialization_constants_buffer";
1996+
addParam(Name, Context.getPointerType(Context.CharTy));
1997+
}
1998+
19491999
void setBody(CompoundStmt *KB) { KernelDecl->setBody(KB); }
19502000

19512001
FunctionDecl *getKernelDecl() { return KernelDecl; }
@@ -2091,28 +2141,46 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
20912141
// pointer-struct-wrapping code to ensure that we don't try to wrap
20922142
// non-top-level pointers.
20932143
uint64_t StructDepth = 0;
2144+
VarDecl *KernelHandlerClone = nullptr;
2145+
2146+
Stmt *replaceWithLocalClone(ParmVarDecl *OriginalParam, VarDecl *LocalClone,
2147+
Stmt *FunctionBody) {
2148+
// DeclRefExpr with valid source location but with decl which is not marked
2149+
// as used is invalid.
2150+
LocalClone->setIsUsed();
2151+
std::pair<DeclaratorDecl *, DeclaratorDecl *> MappingPair =
2152+
std::make_pair(OriginalParam, LocalClone);
2153+
KernelBodyTransform KBT(MappingPair, SemaRef);
2154+
return KBT.TransformStmt(FunctionBody).get();
2155+
}
20942156

20952157
// Using the statements/init expressions that we've created, this generates
20962158
// the kernel body compound stmt. CompoundStmt needs to know its number of
20972159
// statements in advance to allocate it, so we cannot do this as we go along.
20982160
CompoundStmt *createKernelBody() {
2161+
// Push the Kernel function scope to ensure the scope isn't empty
2162+
SemaRef.PushFunctionScope();
2163+
2164+
// Initialize kernel object local clone
20992165
assert(CollectionInitExprs.size() == 1 &&
21002166
"Should have been popped down to just the first one");
21012167
KernelObjClone->setInit(CollectionInitExprs.back());
2102-
Stmt *FunctionBody = KernelCallerFunc->getBody();
2103-
2104-
ParmVarDecl *KernelObjParam = *(KernelCallerFunc->param_begin());
21052168

2106-
// DeclRefExpr with valid source location but with decl which is not marked
2107-
// as used is invalid.
2108-
KernelObjClone->setIsUsed();
2109-
std::pair<DeclaratorDecl *, DeclaratorDecl *> MappingPair =
2110-
std::make_pair(KernelObjParam, KernelObjClone);
2111-
2112-
// Push the Kernel function scope to ensure the scope isn't empty
2113-
SemaRef.PushFunctionScope();
2114-
KernelBodyTransform KBT(MappingPair, SemaRef);
2115-
Stmt *NewBody = KBT.TransformStmt(FunctionBody).get();
2169+
// Replace references to the kernel object in kernel body, to use the
2170+
// compiler generated local clone
2171+
Stmt *NewBody =
2172+
replaceWithLocalClone(KernelCallerFunc->getParamDecl(0), KernelObjClone,
2173+
KernelCallerFunc->getBody());
2174+
2175+
// If kernel_handler argument is passed by SYCL kernel, replace references
2176+
// to this argument in kernel body, to use the compiler generated local
2177+
// clone
2178+
if (ParmVarDecl *KernelHandlerParam =
2179+
getSyclKernelHandlerArg(KernelCallerFunc))
2180+
NewBody = replaceWithLocalClone(KernelHandlerParam, KernelHandlerClone,
2181+
NewBody);
2182+
2183+
// Use transformed body (with clones) as kernel body
21162184
BodyStmts.push_back(NewBody);
21172185

21182186
BodyStmts.insert(BodyStmts.end(), FinalizeStmts.begin(),
@@ -2412,6 +2480,39 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
24122480
return true;
24132481
}
24142482

2483+
// Generate __init call for kernel handler argument
2484+
void handleSpecialType(QualType KernelHandlerTy) {
2485+
DeclRefExpr *KernelHandlerCloneRef =
2486+
DeclRefExpr::Create(SemaRef.Context, NestedNameSpecifierLoc(),
2487+
KernelCallerSrcLoc, KernelHandlerClone, false,
2488+
DeclarationNameInfo(), KernelHandlerTy, VK_LValue);
2489+
const auto *RecordDecl =
2490+
KernelHandlerClone->getType()->getAsCXXRecordDecl();
2491+
MemberExprBases.push_back(KernelHandlerCloneRef);
2492+
createSpecialMethodCall(RecordDecl, InitSpecConstantsBuffer, BodyStmts);
2493+
MemberExprBases.pop_back();
2494+
}
2495+
2496+
void createKernelHandlerClone(ASTContext &Ctx, DeclContext *DC,
2497+
ParmVarDecl *KernelHandlerArg) {
2498+
QualType Ty = KernelHandlerArg->getType();
2499+
TypeSourceInfo *TSInfo = Ctx.getTrivialTypeSourceInfo(Ty);
2500+
KernelHandlerClone =
2501+
VarDecl::Create(Ctx, DC, KernelCallerSrcLoc, KernelCallerSrcLoc,
2502+
KernelHandlerArg->getIdentifier(), Ty, TSInfo, SC_None);
2503+
2504+
// Default initialize clone
2505+
InitializedEntity VarEntity =
2506+
InitializedEntity::InitializeVariable(KernelHandlerClone);
2507+
InitializationKind InitKind =
2508+
InitializationKind::CreateDefault(KernelCallerSrcLoc);
2509+
InitializationSequence InitSeq(SemaRef, VarEntity, InitKind, None);
2510+
ExprResult Init = InitSeq.Perform(SemaRef, VarEntity, InitKind, None);
2511+
KernelHandlerClone->setInit(
2512+
SemaRef.MaybeCreateExprWithCleanups(Init.get()));
2513+
KernelHandlerClone->setInitStyle(VarDecl::CallInit);
2514+
}
2515+
24152516
public:
24162517
static constexpr const bool VisitInsideSimpleContainers = false;
24172518
SyclKernelBodyCreator(Sema &S, SyclKernelDeclCreator &DC,
@@ -2516,6 +2617,28 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
25162617
return true;
25172618
}
25182619

2620+
// Default inits the type, then calls the init-method in the body
2621+
void handleSyclKernelHandlerType(ParmVarDecl *KernelHandlerArg) {
2622+
2623+
// Create and default initialize local clone of kernel handler
2624+
createKernelHandlerClone(SemaRef.getASTContext(),
2625+
DeclCreator.getKernelDecl(), KernelHandlerArg);
2626+
2627+
// Add declaration statement to openCL kernel body
2628+
Stmt *DS =
2629+
new (SemaRef.Context) DeclStmt(DeclGroupRef(KernelHandlerClone),
2630+
KernelCallerSrcLoc, KernelCallerSrcLoc);
2631+
BodyStmts.push_back(DS);
2632+
2633+
// Generate
2634+
// KernelHandlerClone.__init_specialization_constants_buffer(specialization_constants_buffer)
2635+
// call if target does not have native support for specialization constants.
2636+
// Here, specialization_constants_buffer is the compiler generated kernel
2637+
// argument of type char*.
2638+
if (!isDefaultSPIRArch(SemaRef.Context))
2639+
handleSpecialType(KernelHandlerArg->getType());
2640+
}
2641+
25192642
bool enterStream(const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final {
25202643
++StructDepth;
25212644
// Add a dummy init expression to catch the accessor initializers.
@@ -2870,6 +2993,22 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
28702993
return true;
28712994
}
28722995

2996+
void handleSyclKernelHandlerType(QualType Ty) {
2997+
// The compiler generated kernel argument used to initialize SYCL 2020
2998+
// specialization constants, `specialization_constants_buffer`, should
2999+
// have corresponding entry in integration header. This argument is
3000+
// only generated when target has no native support for specialization
3001+
// constants.
3002+
ASTContext &Context = SemaRef.getASTContext();
3003+
if (isDefaultSPIRArch(Context))
3004+
return;
3005+
3006+
// Offset is zero since kernel_handler argument is not part of
3007+
// kernel object (i.e. it is not captured)
3008+
addParam(Context.getPointerType(Context.CharTy),
3009+
SYCLIntegrationHeader::kind_specialization_constants_buffer, 0);
3010+
}
3011+
28733012
bool enterStream(const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final {
28743013
++StructDepth;
28753014
CurOffset += offsetOf(FD, Ty);
@@ -3257,6 +3396,13 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
32573396
KernelObjVisitor Visitor{*this};
32583397
Visitor.VisitRecordBases(KernelObj, kernel_decl, kernel_body, int_header);
32593398
Visitor.VisitRecordFields(KernelObj, kernel_decl, kernel_body, int_header);
3399+
3400+
if (ParmVarDecl *KernelHandlerArg =
3401+
getSyclKernelHandlerArg(KernelCallerFunc)) {
3402+
kernel_decl.handleSyclKernelHandlerType();
3403+
kernel_body.handleSyclKernelHandlerType(KernelHandlerArg);
3404+
int_header.handleSyclKernelHandlerType(KernelHandlerArg->getType());
3405+
}
32603406
}
32613407

32623408
void Sema::MarkDevice(void) {
@@ -3504,6 +3650,7 @@ static const char *paramKind2Str(KernelParamKind K) {
35043650
CASE(accessor);
35053651
CASE(std_layout);
35063652
CASE(sampler);
3653+
CASE(specialization_constants_buffer);
35073654
CASE(pointer);
35083655
}
35093656
return "<ERROR>";
@@ -4089,6 +4236,15 @@ bool Util::isSyclSpecConstantType(const QualType &Ty) {
40894236
return matchQualifiedTypeName(Ty, Scopes);
40904237
}
40914238

4239+
bool Util::isSyclKernelHandlerType(const QualType &Ty) {
4240+
const StringRef &Name = "kernel_handler";
4241+
std::array<DeclContextDesc, 3> Scopes = {
4242+
Util::DeclContextDesc{clang::Decl::Kind::Namespace, "cl"},
4243+
Util::DeclContextDesc{clang::Decl::Kind::Namespace, "sycl"},
4244+
Util::DeclContextDesc{Decl::Kind::CXXRecord, Name}};
4245+
return matchQualifiedTypeName(Ty, Scopes);
4246+
}
4247+
40924248
bool Util::isSyclBufferLocationType(const QualType &Ty) {
40934249
const StringRef &PropertyName = "buffer_location";
40944250
const StringRef &InstanceName = "instance";

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,23 @@ class spec_constant {
291291
} // namespace experimental
292292
} // namespace ONEAPI
293293

294+
class kernel_handler {
295+
void __init_specialization_constants_buffer(char *specialization_constants_buffer) {}
296+
};
297+
294298
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
295299
template <typename KernelName = auto_name, typename KernelType>
296-
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) {
300+
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc) { // #KernelSingleTask
297301
kernelFunc();
298302
}
299303

300304
template <typename KernelName = auto_name, typename KernelType>
301-
ATTR_SYCL_KERNEL void kernel_single_task_2017(KernelType kernelFunc) {
305+
ATTR_SYCL_KERNEL void kernel_single_task(const KernelType &kernelFunc, kernel_handler kh) {
306+
kernelFunc(kh);
307+
}
308+
309+
template <typename KernelName = auto_name, typename KernelType>
310+
ATTR_SYCL_KERNEL void kernel_single_task_2017(KernelType kernelFunc) { // #KernelSingleTask2017
302311
kernelFunc();
303312
}
304313

@@ -347,6 +356,16 @@ class handler {
347356
#endif
348357
}
349358

359+
template <typename KernelName = auto_name, typename KernelType>
360+
void single_task(const KernelType &kernelFunc, kernel_handler kh) {
361+
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
362+
#ifdef __SYCL_DEVICE_ONLY__
363+
kernel_single_task<NameT>(kernelFunc, kh);
364+
#else
365+
kernelFunc(kh);
366+
#endif
367+
}
368+
350369
template <typename KernelName = auto_name, typename KernelType>
351370
void single_task_2017(KernelType kernelFunc) {
352371
using NameT = typename get_kernel_name_t<KernelName, KernelType>::name;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-unknown-unknown -fsycl-int-header=%t.h %s -o %t.out %s -o %t.out
2+
// RUN: FileCheck -input-file=%t.h %s --check-prefix=NONATIVESUPPORT --check-prefix=ALL
3+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -o %t.out %s -o %t.out
4+
// RUN: FileCheck -input-file=%t.h %s --check-prefix=NATIVESUPPORT --check-prefix=ALL
5+
6+
// This test checks that the compiler generates required information
7+
// in integration header for kernel_handler type (SYCL 2020 specialization
8+
// constants).
9+
10+
#include "sycl.hpp"
11+
12+
using namespace cl::sycl;
13+
queue q;
14+
15+
int main() {
16+
q.submit([&](handler &h) {
17+
int a;
18+
kernel_handler kh;
19+
20+
h.single_task<class test_kernel_handler>(
21+
[=](auto) {
22+
int local = a;
23+
},
24+
kh);
25+
});
26+
}
27+
// ALL: const kernel_param_desc_t kernel_signatures[] = {
28+
// NONATIVESUPPORT: { kernel_param_kind_t::kind_specialization_constants_buffer, 8, 0 }
29+
// NATIVESUPPORT-NOT: { kernel_param_kind_t::kind_specialization_constants_buffer, 8, 0 }

0 commit comments

Comments
 (0)