@@ -58,6 +58,8 @@ enum KernelInvocationKind {
58
58
59
59
static constexpr llvm::StringLiteral InitMethodName = " __init" ;
60
60
static constexpr llvm::StringLiteral InitESIMDMethodName = " __init_esimd" ;
61
+ static constexpr llvm::StringLiteral InitSpecConstantsBuffer =
62
+ " __init_specialization_constants_buffer" ;
61
63
static constexpr llvm::StringLiteral FinalizeMethodName = " __finalize" ;
62
64
constexpr unsigned MaxKernelArgsSize = 2048 ;
63
65
@@ -109,6 +111,10 @@ class Util {
109
111
// / specialization constant class.
110
112
static bool isSyclSpecConstantType (const QualType &Ty);
111
113
114
+ // / Checks whether given clang type is a full specialization of the SYCL
115
+ // / kernel_handler class.
116
+ static bool isSyclKernelHandlerType (const QualType &Ty);
117
+
112
118
// Checks declaration context hierarchy.
113
119
// / \param DC the context of the item to be checked.
114
120
// / \param Scopes the declaration scopes leading from the item context to the
@@ -616,11 +622,16 @@ class FindPFWGLambdaFnVisitor
616
622
auto *M = dyn_cast<CXXMethodDecl>(Call->getDirectCallee ());
617
623
if (!M || (M->getOverloadedOperator () != OO_Call))
618
624
return true ;
619
- const int NumPFWGLambdaArgs = 2 ; // group and lambda obj
625
+
626
+ unsigned int NumPFWGLambdaArgs =
627
+ M->getNumParams () + 1 ; // group, optional kernel_handler and lambda obj
620
628
if (Call->getNumArgs () != NumPFWGLambdaArgs)
621
629
return true ;
622
630
if (!Util::isSyclType (Call->getArg (1 )->getType (), " group" , true /* Tmpl*/ ))
623
631
return true ;
632
+ if ((Call->getNumArgs () > 2 ) &&
633
+ !Util::isSyclKernelHandlerType (Call->getArg (2 )->getType ()))
634
+ return true ;
624
635
if (Call->getArg (0 )->getType ()->getAsCXXRecordDecl () != LambdaObjTy)
625
636
return true ;
626
637
LambdaFn = M; // call to PFWG lambda found - record the lambda
@@ -732,12 +743,7 @@ static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) {
732
743
Ctx.getTrivialTypeSourceInfo (Ty));
733
744
}
734
745
735
- static ParamDesc makeParamDesc (ASTContext &Ctx, const CXXBaseSpecifier &Src,
736
- QualType Ty) {
737
- // TODO: There is no name for the base available, but duplicate names are
738
- // seemingly already possible, so we'll give them all the same name for now.
739
- // This only happens with the accessor types.
740
- std::string Name = " _arg__base" ;
746
+ static ParamDesc makeParamDesc (ASTContext &Ctx, StringRef Name, QualType Ty) {
741
747
return std::make_tuple (Ty, &Ctx.Idents .get (Name),
742
748
Ctx.getTrivialTypeSourceInfo (Ty));
743
749
}
@@ -777,6 +783,28 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc,
777
783
KernelNameType)};
778
784
}
779
785
786
+ static bool isDefaultSPIRArch (ASTContext &Context) {
787
+ llvm::Triple T = Context.getTargetInfo ().getTriple ();
788
+ if (T.isSPIR () && T.getSubArch () == llvm::Triple::NoSubArch)
789
+ return true ;
790
+ return false ;
791
+ }
792
+
793
+ static ParmVarDecl *getSyclKernelHandlerArg (FunctionDecl *KernelCallerFunc) {
794
+ // Specialization constants in SYCL 2020 are not captured by lambda and
795
+ // accessed through new optional lambda argument kernel_handler
796
+ auto IsHandlerLambda = [](ParmVarDecl *PVD) {
797
+ return Util::isSyclKernelHandlerType (PVD->getType ());
798
+ };
799
+
800
+ assert (llvm::count_if (KernelCallerFunc->parameters (), IsHandlerLambda) <= 1 &&
801
+ " Multiple kernel_handler parameters" );
802
+
803
+ auto KHArg = llvm::find_if (KernelCallerFunc->parameters (), IsHandlerLambda);
804
+
805
+ return (KHArg != KernelCallerFunc->param_end ()) ? *KHArg : nullptr ;
806
+ }
807
+
780
808
// anonymous namespace so these don't get linkage.
781
809
namespace {
782
810
@@ -1642,10 +1670,20 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1642
1670
}
1643
1671
1644
1672
void addParam (const CXXBaseSpecifier &BS, QualType FieldTy) {
1673
+ // TODO: There is no name for the base available, but duplicate names are
1674
+ // seemingly already possible, so we'll give them all the same name for now.
1675
+ // This only happens with the accessor types.
1676
+ StringRef Name = " _arg__base" ;
1645
1677
ParamDesc newParamDesc =
1646
- makeParamDesc (SemaRef.getASTContext (), BS , FieldTy);
1678
+ makeParamDesc (SemaRef.getASTContext (), Name , FieldTy);
1647
1679
addParam (newParamDesc, FieldTy);
1648
1680
}
1681
+ // Add a parameter with specified name and type
1682
+ void addParam (StringRef Name, QualType ParamTy) {
1683
+ ParamDesc newParamDesc =
1684
+ makeParamDesc (SemaRef.getASTContext (), Name, ParamTy);
1685
+ addParam (newParamDesc, ParamTy);
1686
+ }
1649
1687
1650
1688
void addParam (ParamDesc newParamDesc, QualType FieldTy) {
1651
1689
// Create a new ParmVarDecl based on the new info.
@@ -1946,6 +1984,18 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1946
1984
return true ;
1947
1985
}
1948
1986
1987
+ // Generate kernel argument to intialize specialization constants. This
1988
+ // argument is only generated when the target has no native support for
1989
+ // specialization constants
1990
+ void handleSyclKernelHandlerType () {
1991
+ ASTContext &Context = SemaRef.getASTContext ();
1992
+ if (isDefaultSPIRArch (Context))
1993
+ return ;
1994
+
1995
+ StringRef Name = " _arg__specialization_constants_buffer" ;
1996
+ addParam (Name, Context.getPointerType (Context.CharTy ));
1997
+ }
1998
+
1949
1999
void setBody (CompoundStmt *KB) { KernelDecl->setBody (KB); }
1950
2000
1951
2001
FunctionDecl *getKernelDecl () { return KernelDecl; }
@@ -2091,28 +2141,46 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2091
2141
// pointer-struct-wrapping code to ensure that we don't try to wrap
2092
2142
// non-top-level pointers.
2093
2143
uint64_t StructDepth = 0 ;
2144
+ VarDecl *KernelHandlerClone = nullptr ;
2145
+
2146
+ Stmt *replaceWithLocalClone (ParmVarDecl *OriginalParam, VarDecl *LocalClone,
2147
+ Stmt *FunctionBody) {
2148
+ // DeclRefExpr with valid source location but with decl which is not marked
2149
+ // as used is invalid.
2150
+ LocalClone->setIsUsed ();
2151
+ std::pair<DeclaratorDecl *, DeclaratorDecl *> MappingPair =
2152
+ std::make_pair (OriginalParam, LocalClone);
2153
+ KernelBodyTransform KBT (MappingPair, SemaRef);
2154
+ return KBT.TransformStmt (FunctionBody).get ();
2155
+ }
2094
2156
2095
2157
// Using the statements/init expressions that we've created, this generates
2096
2158
// the kernel body compound stmt. CompoundStmt needs to know its number of
2097
2159
// statements in advance to allocate it, so we cannot do this as we go along.
2098
2160
CompoundStmt *createKernelBody () {
2161
+ // Push the Kernel function scope to ensure the scope isn't empty
2162
+ SemaRef.PushFunctionScope ();
2163
+
2164
+ // Initialize kernel object local clone
2099
2165
assert (CollectionInitExprs.size () == 1 &&
2100
2166
" Should have been popped down to just the first one" );
2101
2167
KernelObjClone->setInit (CollectionInitExprs.back ());
2102
- Stmt *FunctionBody = KernelCallerFunc->getBody ();
2103
-
2104
- ParmVarDecl *KernelObjParam = *(KernelCallerFunc->param_begin ());
2105
2168
2106
- // DeclRefExpr with valid source location but with decl which is not marked
2107
- // as used is invalid.
2108
- KernelObjClone->setIsUsed ();
2109
- std::pair<DeclaratorDecl *, DeclaratorDecl *> MappingPair =
2110
- std::make_pair (KernelObjParam, KernelObjClone);
2111
-
2112
- // Push the Kernel function scope to ensure the scope isn't empty
2113
- SemaRef.PushFunctionScope ();
2114
- KernelBodyTransform KBT (MappingPair, SemaRef);
2115
- Stmt *NewBody = KBT.TransformStmt (FunctionBody).get ();
2169
+ // Replace references to the kernel object in kernel body, to use the
2170
+ // compiler generated local clone
2171
+ Stmt *NewBody =
2172
+ replaceWithLocalClone (KernelCallerFunc->getParamDecl (0 ), KernelObjClone,
2173
+ KernelCallerFunc->getBody ());
2174
+
2175
+ // If kernel_handler argument is passed by SYCL kernel, replace references
2176
+ // to this argument in kernel body, to use the compiler generated local
2177
+ // clone
2178
+ if (ParmVarDecl *KernelHandlerParam =
2179
+ getSyclKernelHandlerArg (KernelCallerFunc))
2180
+ NewBody = replaceWithLocalClone (KernelHandlerParam, KernelHandlerClone,
2181
+ NewBody);
2182
+
2183
+ // Use transformed body (with clones) as kernel body
2116
2184
BodyStmts.push_back (NewBody);
2117
2185
2118
2186
BodyStmts.insert (BodyStmts.end (), FinalizeStmts.begin (),
@@ -2412,6 +2480,39 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2412
2480
return true ;
2413
2481
}
2414
2482
2483
+ // Generate __init call for kernel handler argument
2484
+ void handleSpecialType (QualType KernelHandlerTy) {
2485
+ DeclRefExpr *KernelHandlerCloneRef =
2486
+ DeclRefExpr::Create (SemaRef.Context , NestedNameSpecifierLoc (),
2487
+ KernelCallerSrcLoc, KernelHandlerClone, false ,
2488
+ DeclarationNameInfo (), KernelHandlerTy, VK_LValue);
2489
+ const auto *RecordDecl =
2490
+ KernelHandlerClone->getType ()->getAsCXXRecordDecl ();
2491
+ MemberExprBases.push_back (KernelHandlerCloneRef);
2492
+ createSpecialMethodCall (RecordDecl, InitSpecConstantsBuffer, BodyStmts);
2493
+ MemberExprBases.pop_back ();
2494
+ }
2495
+
2496
+ void createKernelHandlerClone (ASTContext &Ctx, DeclContext *DC,
2497
+ ParmVarDecl *KernelHandlerArg) {
2498
+ QualType Ty = KernelHandlerArg->getType ();
2499
+ TypeSourceInfo *TSInfo = Ctx.getTrivialTypeSourceInfo (Ty);
2500
+ KernelHandlerClone =
2501
+ VarDecl::Create (Ctx, DC, KernelCallerSrcLoc, KernelCallerSrcLoc,
2502
+ KernelHandlerArg->getIdentifier (), Ty, TSInfo, SC_None);
2503
+
2504
+ // Default initialize clone
2505
+ InitializedEntity VarEntity =
2506
+ InitializedEntity::InitializeVariable (KernelHandlerClone);
2507
+ InitializationKind InitKind =
2508
+ InitializationKind::CreateDefault (KernelCallerSrcLoc);
2509
+ InitializationSequence InitSeq (SemaRef, VarEntity, InitKind, None);
2510
+ ExprResult Init = InitSeq.Perform (SemaRef, VarEntity, InitKind, None);
2511
+ KernelHandlerClone->setInit (
2512
+ SemaRef.MaybeCreateExprWithCleanups (Init.get ()));
2513
+ KernelHandlerClone->setInitStyle (VarDecl::CallInit);
2514
+ }
2515
+
2415
2516
public:
2416
2517
static constexpr const bool VisitInsideSimpleContainers = false ;
2417
2518
SyclKernelBodyCreator (Sema &S, SyclKernelDeclCreator &DC,
@@ -2516,6 +2617,28 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
2516
2617
return true ;
2517
2618
}
2518
2619
2620
+ // Default inits the type, then calls the init-method in the body
2621
+ void handleSyclKernelHandlerType (ParmVarDecl *KernelHandlerArg) {
2622
+
2623
+ // Create and default initialize local clone of kernel handler
2624
+ createKernelHandlerClone (SemaRef.getASTContext (),
2625
+ DeclCreator.getKernelDecl (), KernelHandlerArg);
2626
+
2627
+ // Add declaration statement to openCL kernel body
2628
+ Stmt *DS =
2629
+ new (SemaRef.Context ) DeclStmt (DeclGroupRef (KernelHandlerClone),
2630
+ KernelCallerSrcLoc, KernelCallerSrcLoc);
2631
+ BodyStmts.push_back (DS);
2632
+
2633
+ // Generate
2634
+ // KernelHandlerClone.__init_specialization_constants_buffer(specialization_constants_buffer)
2635
+ // call if target does not have native support for specialization constants.
2636
+ // Here, specialization_constants_buffer is the compiler generated kernel
2637
+ // argument of type char*.
2638
+ if (!isDefaultSPIRArch (SemaRef.Context ))
2639
+ handleSpecialType (KernelHandlerArg->getType ());
2640
+ }
2641
+
2519
2642
bool enterStream (const CXXRecordDecl *RD, FieldDecl *FD, QualType Ty) final {
2520
2643
++StructDepth;
2521
2644
// Add a dummy init expression to catch the accessor initializers.
@@ -2870,6 +2993,22 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
2870
2993
return true ;
2871
2994
}
2872
2995
2996
+ void handleSyclKernelHandlerType (QualType Ty) {
2997
+ // The compiler generated kernel argument used to initialize SYCL 2020
2998
+ // specialization constants, `specialization_constants_buffer`, should
2999
+ // have corresponding entry in integration header. This argument is
3000
+ // only generated when target has no native support for specialization
3001
+ // constants.
3002
+ ASTContext &Context = SemaRef.getASTContext ();
3003
+ if (isDefaultSPIRArch (Context))
3004
+ return ;
3005
+
3006
+ // Offset is zero since kernel_handler argument is not part of
3007
+ // kernel object (i.e. it is not captured)
3008
+ addParam (Context.getPointerType (Context.CharTy ),
3009
+ SYCLIntegrationHeader::kind_specialization_constants_buffer, 0 );
3010
+ }
3011
+
2873
3012
bool enterStream (const CXXRecordDecl *, FieldDecl *FD, QualType Ty) final {
2874
3013
++StructDepth;
2875
3014
CurOffset += offsetOf (FD, Ty);
@@ -3257,6 +3396,13 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
3257
3396
KernelObjVisitor Visitor{*this };
3258
3397
Visitor.VisitRecordBases (KernelObj, kernel_decl, kernel_body, int_header);
3259
3398
Visitor.VisitRecordFields (KernelObj, kernel_decl, kernel_body, int_header);
3399
+
3400
+ if (ParmVarDecl *KernelHandlerArg =
3401
+ getSyclKernelHandlerArg (KernelCallerFunc)) {
3402
+ kernel_decl.handleSyclKernelHandlerType ();
3403
+ kernel_body.handleSyclKernelHandlerType (KernelHandlerArg);
3404
+ int_header.handleSyclKernelHandlerType (KernelHandlerArg->getType ());
3405
+ }
3260
3406
}
3261
3407
3262
3408
void Sema::MarkDevice (void ) {
@@ -3504,6 +3650,7 @@ static const char *paramKind2Str(KernelParamKind K) {
3504
3650
CASE (accessor);
3505
3651
CASE (std_layout);
3506
3652
CASE (sampler);
3653
+ CASE (specialization_constants_buffer);
3507
3654
CASE (pointer);
3508
3655
}
3509
3656
return " <ERROR>" ;
@@ -4089,6 +4236,15 @@ bool Util::isSyclSpecConstantType(const QualType &Ty) {
4089
4236
return matchQualifiedTypeName (Ty, Scopes);
4090
4237
}
4091
4238
4239
+ bool Util::isSyclKernelHandlerType (const QualType &Ty) {
4240
+ const StringRef &Name = " kernel_handler" ;
4241
+ std::array<DeclContextDesc, 3 > Scopes = {
4242
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
4243
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" },
4244
+ Util::DeclContextDesc{Decl::Kind::CXXRecord, Name}};
4245
+ return matchQualifiedTypeName (Ty, Scopes);
4246
+ }
4247
+
4092
4248
bool Util::isSyclBufferLocationType (const QualType &Ty) {
4093
4249
const StringRef &PropertyName = " buffer_location" ;
4094
4250
const StringRef &InstanceName = " instance" ;
0 commit comments