1212#include " clang/AST/AST.h"
1313#include " clang/Sema/Sema.h"
1414#include " llvm/ADT/SmallVector.h"
15+ #include " TreeTransform.h"
1516
1617using namespace clang ;
1718
18- LambdaExpr *getBodyAsLambda (CXXMemberCallExpr *e) {
19- auto LastArg = e->getArg (e->getNumArgs () - 1 );
20- return dyn_cast<LambdaExpr>(LastArg);
19+ typedef llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> DeclMap;
20+
21+ class KernelBodyTransform : public TreeTransform <KernelBodyTransform> {
22+ public:
23+ KernelBodyTransform (llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> &Map,
24+ Sema &S)
25+ : TreeTransform<KernelBodyTransform>(S), DMap(Map), SemaRef(S) {}
26+ bool AlwaysRebuild () { return true ; }
27+
28+ ExprResult TransformDeclRefExpr (DeclRefExpr *DRE) {
29+ auto Ref = dyn_cast<DeclaratorDecl>(DRE->getDecl ());
30+ if (Ref) {
31+ auto NewDecl = DMap[Ref];
32+ if (NewDecl) {
33+ return DeclRefExpr::Create (
34+ SemaRef.getASTContext (), DRE->getQualifierLoc (),
35+ DRE->getTemplateKeywordLoc (), NewDecl, false , DRE->getNameInfo (),
36+ NewDecl->getType (), DRE->getValueKind ());
37+ }
38+ }
39+ return DRE;
40+ }
41+
42+ private:
43+ DeclMap DMap;
44+ Sema &SemaRef;
45+ };
46+
47+ CXXRecordDecl* getBodyAsLambda (FunctionDecl *FD) {
48+ auto FirstArg = (*FD->param_begin ());
49+ if (FirstArg)
50+ if (FirstArg->getType ()->getAsCXXRecordDecl ()->isLambda ())
51+ return FirstArg->getType ()->getAsCXXRecordDecl ();
52+ return nullptr ;
2153}
2254
2355FunctionDecl *CreateSYCLKernelFunction (ASTContext &Context, StringRef Name,
@@ -54,17 +86,16 @@ FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
5486 return Result;
5587}
5688
57- CompoundStmt *CreateSYCLKernelBody (Sema &S, CXXMemberCallExpr *e ,
89+ CompoundStmt *CreateSYCLKernelBody (Sema &S, FunctionDecl *KernelHelper ,
5890 DeclContext *DC) {
5991
6092 llvm::SmallVector<Stmt *, 16 > BodyStmts;
6193
6294 // TODO: case when kernel is functor
6395 // TODO: possible refactoring when functor case will be completed
64- LambdaExpr *LE = getBodyAsLambda (e );
65- if (LE ) {
96+ CXXRecordDecl *LC = getBodyAsLambda (KernelHelper );
97+ if (LC ) {
6698 // Create Lambda object
67- CXXRecordDecl *LC = LE->getLambdaClass ();
6899 auto LambdaVD = VarDecl::Create (
69100 S.Context , DC, SourceLocation (), SourceLocation (), LC->getIdentifier (),
70101 QualType (LC->getTypeForDecl (), 0 ), LC->getLambdaTypeInfo (), SC_None);
@@ -137,43 +168,23 @@ CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
137168 TargetFuncParam++;
138169 }
139170
140- // Create Lambda operator () call
141- FunctionDecl *LO = LE->getCallOperator ();
142- ArrayRef<ParmVarDecl *> Args = LO->parameters ();
143- llvm::SmallVector<Expr *, 16 > ParamStmts (1 );
144- ParamStmts[0 ] = dyn_cast<Expr>(LambdaDRE);
145-
146- // Collect arguments for () operator
147- for (auto Arg : Args) {
148- QualType ArgType = Arg->getOriginalType ();
149- // Declare variable for parameter and pass it to call
150- auto param_VD =
151- VarDecl::Create (S.Context , DC, SourceLocation (), SourceLocation (),
152- Arg->getIdentifier (), ArgType,
153- S.Context .getTrivialTypeSourceInfo (ArgType), SC_None);
154- Stmt *param_DS = new (S.Context )
155- DeclStmt (DeclGroupRef (param_VD), SourceLocation (), SourceLocation ());
156- BodyStmts.push_back (param_DS);
157- auto DRE = DeclRefExpr::Create (S.Context , NestedNameSpecifierLoc (),
158- SourceLocation (), param_VD, false ,
159- DeclarationNameInfo (), ArgType, VK_LValue);
160- Expr *Res = ImplicitCastExpr::Create (
161- S.Context , ArgType, CK_LValueToRValue, DRE, nullptr , VK_RValue);
162- ParamStmts.push_back (Res);
163- }
171+ // In function from headers lambda is function parameter, we need
172+ // to replace all refs to this lambda with our vardecl.
173+ // I used TreeTransform here, but I'm not sure that it is good solution
174+ // Also I used map and I'm not sure about it too.
175+ Stmt* FunctionBody = KernelHelper->getBody ();
176+ DeclMap DMap;
177+ ParmVarDecl* LambdaParam = *(KernelHelper->param_begin ());
178+ // DeclRefExpr with valid source location but with decl which is not marked
179+ // as used is invalid.
180+ LambdaVD->setIsUsed ();
181+ DMap[LambdaParam] = LambdaVD;
182+ // Without PushFunctionScope I had segfault. Maybe we also need to do pop.
183+ S.PushFunctionScope ();
184+ KernelBodyTransform KBT (DMap, S);
185+ Stmt* NewBody = KBT.TransformStmt (FunctionBody).get ();
186+ BodyStmts.push_back (NewBody);
164187
165- // Create ref for call operator
166- DeclRefExpr *DRE = new (S.Context )
167- DeclRefExpr (S.Context , LO, false , LO->getType (), VK_LValue,
168- SourceLocation ());
169- QualType ResultTy = LO->getReturnType ();
170- ExprValueKind VK = Expr::getValueKindForType (ResultTy);
171- ResultTy = ResultTy.getNonLValueExprType (S.Context );
172-
173- CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create (
174- S.Context , OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation (),
175- FPOptions (), clang::CallExpr::ADLCallKind::NotADL );
176- BodyStmts.push_back (TheCall);
177188 }
178189 return CompoundStmt::Create (S.Context , BodyStmts, SourceLocation (),
179190 SourceLocation ());
@@ -222,9 +233,9 @@ void BuildArgTys(ASTContext &Context,
222233 }
223234}
224235
225- void Sema::ConstructSYCLKernel (CXXMemberCallExpr *e ) {
236+ void Sema::ConstructSYCLKernel (FunctionDecl *KernelHelper ) {
226237 // TODO: Case when kernel is functor
227- LambdaExpr *LE = getBodyAsLambda (e );
238+ CXXRecordDecl *LE = getBodyAsLambda (KernelHelper );
228239 if (LE) {
229240
230241 llvm::SmallVector<DeclaratorDecl *, 16 > ArgDecls;
@@ -238,9 +249,8 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
238249 BuildArgTys (getASTContext (), ArgDecls, NewArgDecls, ArgTys);
239250
240251 // Get Name for our kernel.
241- FunctionDecl *FuncDecl = e->getMethodDecl ();
242252 const TemplateArgumentList *TemplateArgs =
243- FuncDecl ->getTemplateSpecializationArgs ();
253+ KernelHelper ->getTemplateSpecializationArgs ();
244254 QualType KernelNameType = TemplateArgs->get (0 ).getAsType ();
245255 std::string Name = KernelNameType.getBaseTypeIdentifier ()->getName ().str ();
246256
@@ -256,7 +266,7 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
256266 FunctionDecl *SYCLKernel =
257267 CreateSYCLKernelFunction (getASTContext (), Name, ArgTys, NewArgDecls);
258268
259- CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody (*this , e , SYCLKernel);
269+ CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody (*this , KernelHelper , SYCLKernel);
260270 SYCLKernel->setBody (SYCLKernelBody);
261271
262272 AddSyclKernel (SYCLKernel);
0 commit comments