Skip to content

Commit 03354a2

Browse files
committed
[SYCL] Add SYCL Kernel entry point generation.
Main changes are: 1. Added of parallel_for and single_task kernel invoking functions search. 2. Added kernel name extraction from single_task/parallel_for template parameter. 3. Added SYCL kernel entry point generation. 4. Non-kernel code is not emmited for sycl device now. Signed-off-by: Vladimir Lazarev <vladimir.lazarev@intel.com>
1 parent f509e63 commit 03354a2

File tree

6 files changed

+261
-0
lines changed

6 files changed

+261
-0
lines changed

clang/include/clang/Sema/Sema.h

+11
Original file line numberDiff line numberDiff line change
@@ -10842,6 +10842,17 @@ class Sema {
1084210842
Expr *E,
1084310843
llvm::function_ref<void(Expr *, RecordDecl *, FieldDecl *, CharUnits)>
1084410844
Action);
10845+
10846+
private:
10847+
// We store SYCL Kernels here and handle separately -- which is a hack.
10848+
// FIXME: It would be best to refactor this.
10849+
SmallVector<Decl*, 4> SyclKernel;
10850+
10851+
public:
10852+
void AddSyclKernel(Decl * d) { SyclKernel.push_back(d); }
10853+
SmallVector<Decl*, 4> &SyclKernels() { return SyclKernel; }
10854+
10855+
void ConstructSYCLKernel(CXXMemberCallExpr* e);
1084510856
};
1084610857

1084710858
/// RAII object that enters a new expression evaluation context.

clang/lib/CodeGen/CodeGenModule.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,11 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
21282128
if (Global->hasAttr<IFuncAttr>())
21292129
return emitIFuncDefinition(GD);
21302130

2131+
if (LangOpts.SYCL) {
2132+
if (!Global->hasAttr<OpenCLKernelAttr>())
2133+
return;
2134+
}
2135+
21312136
// If this is a cpu_dispatch multiversion function, emit the resolver.
21322137
if (Global->hasAttr<CPUDispatchAttr>())
21332138
return emitCPUDispatchDefinition(GD);

clang/lib/Parse/ParseAST.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ void clang::ParseAST(Sema &S, bool PrintStats, bool SkipFunctionBodies) {
167167
for (Decl *D : S.WeakTopLevelDecls())
168168
Consumer->HandleTopLevelDecl(DeclGroupRef(D));
169169

170+
if (S.getLangOpts().SYCL) {
171+
for (Decl *D : S.SyclKernels()) {
172+
Consumer->HandleTopLevelDecl(DeclGroupRef(D));
173+
}
174+
}
175+
170176
Consumer->HandleTranslationUnit(S.getASTContext());
171177

172178
// Finalize the template instantiation observer chain.

clang/lib/Sema/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_clang_library(clangSema
5050
SemaStmt.cpp
5151
SemaStmtAsm.cpp
5252
SemaStmtAttr.cpp
53+
SemaSYCL.cpp
5354
SemaTemplate.cpp
5455
SemaTemplateDeduction.cpp
5556
SemaTemplateInstantiate.cpp

clang/lib/Sema/SemaOverload.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -13012,6 +13012,15 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
1301213012
CXXMemberCallExpr::Create(Context, MemExprE, Args, ResultType, VK,
1301313013
RParenLoc, Proto->getNumParams());
1301413014

13015+
if (getLangOpts().SYCL) {
13016+
auto Func = TheCall->getMethodDecl();
13017+
auto Name = Func->getQualifiedNameAsString();
13018+
if (Name == "cl::sycl::handler::parallel_for" ||
13019+
Name == "cl::sycl::handler::single_task") {
13020+
ConstructSYCLKernel(TheCall);
13021+
}
13022+
}
13023+
1301513024
// Check for a valid return type.
1301613025
if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(),
1301713026
TheCall, Method))

clang/lib/Sema/SemaSYCL.cpp

+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
//===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
2+
//
3+
// The LLVM Compiler Infrastructure
4+
//
5+
// This file is distributed under the University of Illinois Open Source
6+
// License. See LICENSE.TXT for details.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
// This implements Semantic Analysis for SYCL constructs.
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "clang/AST/AST.h"
13+
#include "clang/Sema/Sema.h"
14+
#include "llvm/ADT/SmallVector.h"
15+
16+
using namespace clang;
17+
18+
LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) {
19+
auto LastArg = e->getArg(e->getNumArgs() - 1);
20+
return dyn_cast<LambdaExpr>(LastArg);
21+
}
22+
23+
FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
24+
ArrayRef<QualType> ArgTys,
25+
ArrayRef<DeclaratorDecl *> ArgDecls) {
26+
27+
DeclContext *DC = Context.getTranslationUnitDecl();
28+
FunctionProtoType::ExtProtoInfo Info;
29+
QualType RetTy = Context.VoidTy;
30+
QualType FuncTy = Context.getFunctionType(RetTy, ArgTys, Info);
31+
DeclarationName DN = DeclarationName(&Context.Idents.get(Name));
32+
FunctionDecl *Result = FunctionDecl::Create(
33+
Context, DC, SourceLocation(), SourceLocation(), DN, FuncTy,
34+
Context.getTrivialTypeSourceInfo(RetTy), SC_None);
35+
llvm::SmallVector<ParmVarDecl *, 16> Params;
36+
int i = 0;
37+
for (auto ArgTy : ArgTys) {
38+
auto P =
39+
ParmVarDecl::Create(Context, Result, SourceLocation(), SourceLocation(),
40+
ArgDecls[i]->getIdentifier(), ArgTy,
41+
ArgDecls[i]->getTypeSourceInfo(), SC_None, 0);
42+
P->setScopeInfo(0, i++);
43+
P->setIsUsed();
44+
Params.push_back(P);
45+
}
46+
Result->setParams(Params);
47+
// TODO: Add SYCL specific attribute for kernel and all functions called
48+
// by kernel.
49+
Result->addAttr(OpenCLKernelAttr::CreateImplicit(Context));
50+
Result->addAttr(AsmLabelAttr::CreateImplicit(Context, Name));
51+
return Result;
52+
}
53+
54+
CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
55+
DeclContext *DC) {
56+
57+
llvm::SmallVector<Stmt *, 16> BodyStmts;
58+
59+
// TODO: case when kernel is functor
60+
// TODO: possible refactoring when functor case will be completed
61+
LambdaExpr *LE = getBodyAsLambda(e);
62+
if (LE) {
63+
// Create Lambda object
64+
CXXRecordDecl *LC = LE->getLambdaClass();
65+
auto Lambda_VD = VarDecl::Create(
66+
S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(),
67+
QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None);
68+
Stmt *DS = new (S.Context)
69+
DeclStmt(DeclGroupRef(Lambda_VD), SourceLocation(), SourceLocation());
70+
BodyStmts.push_back(DS);
71+
auto Lambda_DRE = DeclRefExpr::Create(
72+
S.Context, NestedNameSpecifierLoc(), SourceLocation(), Lambda_VD, false,
73+
DeclarationNameInfo(), QualType(LC->getTypeForDecl(), 0), VK_LValue);
74+
75+
// Init Lambda fields
76+
llvm::SmallVector<Expr *, 16> InitCaptures;
77+
78+
auto TargetFunc = dyn_cast<FunctionDecl>(DC);
79+
auto TargetFuncParam =
80+
TargetFunc->param_begin(); // Iterator to ParamVarDecl (VarDecl)
81+
for (auto CaptureField : LE->captures()) {
82+
VarDecl *CapturedVar =
83+
CaptureField
84+
.getCapturedVar(); // accessor, need to do setInit for this
85+
QualType ParamType = (*TargetFuncParam)->getOriginalType();
86+
auto DRE = DeclRefExpr::Create(
87+
S.Context, NestedNameSpecifierLoc(), SourceLocation(),
88+
*TargetFuncParam, false, DeclarationNameInfo(), ParamType, VK_LValue);
89+
90+
Expr *Res = ImplicitCastExpr::Create(
91+
S.Context, ParamType, CK_LValueToRValue, DRE, nullptr, VK_RValue);
92+
93+
Expr *InitCapture = new (S.Context) InitListExpr(
94+
S.Context, SourceLocation(), /*initExprs*/ Res, SourceLocation());
95+
CapturedVar->setInit(InitCapture);
96+
InitCapture->setType(CapturedVar->getType());
97+
InitCaptures.push_back(InitCapture);
98+
TargetFuncParam++;
99+
}
100+
101+
Expr *InitLambdaCaptures = new (S.Context)
102+
InitListExpr(S.Context, SourceLocation(), /*initExprs*/ InitCaptures,
103+
SourceLocation());
104+
InitLambdaCaptures->setType(Lambda_VD->getType());
105+
Lambda_VD->setInit(InitLambdaCaptures);
106+
107+
// Create Lambda operator () call
108+
FunctionDecl *LO = LE->getCallOperator();
109+
ArrayRef<ParmVarDecl *> Args = LO->parameters();
110+
llvm::SmallVector<Expr *, 16> ParamStmts(1);
111+
ParamStmts[0] = dyn_cast<Expr>(Lambda_DRE);
112+
113+
// Collect arguments for () operator
114+
for (auto Arg : Args) {
115+
QualType ArgType = Arg->getOriginalType();
116+
// Declare variable for parameter and pass it to call
117+
auto param_VD =
118+
VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(),
119+
Arg->getIdentifier(), ArgType,
120+
S.Context.getTrivialTypeSourceInfo(ArgType), SC_None);
121+
Stmt *param_DS = new (S.Context)
122+
DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation());
123+
BodyStmts.push_back(param_DS);
124+
auto DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
125+
SourceLocation(), param_VD, false,
126+
DeclarationNameInfo(), ArgType, VK_LValue);
127+
ParamStmts.push_back(DRE);
128+
}
129+
130+
// Create ref for call operator
131+
DeclRefExpr *DRE = new (S.Context)
132+
DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue,
133+
SourceLocation());
134+
QualType ResultTy = LO->getReturnType();
135+
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
136+
ResultTy = ResultTy.getNonLValueExprType(S.Context);
137+
138+
CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create(
139+
S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(),
140+
FPOptions(), clang::CallExpr::ADLCallKind::NotADL );
141+
BodyStmts.push_back(TheCall);
142+
}
143+
return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(),
144+
SourceLocation());
145+
}
146+
147+
void BuildArgTys(ASTContext &Context,
148+
llvm::SmallVector<DeclaratorDecl *, 16> &ArgDecls,
149+
llvm::SmallVector<DeclaratorDecl *, 16> &NewArgDecls,
150+
llvm::SmallVector<QualType, 16> &ArgTys) {
151+
for (auto V : ArgDecls) {
152+
QualType ArgTy = V->getType();
153+
QualType ActualArgType = ArgTy;
154+
StringRef Name = ArgTy.getBaseTypeIdentifier()->getName();
155+
// TODO: harden this check with additional validation that this class is
156+
// declared in cl::sycl namespace
157+
if (std::string(Name) == "accessor") {
158+
if (const auto *RecordDecl = ArgTy->getAsCXXRecordDecl()) {
159+
const auto *TemplateDecl =
160+
dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
161+
if (TemplateDecl) {
162+
QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType();
163+
Qualifiers Quals = PointeeType.getQualifiers();
164+
Quals.setAddressSpace(LangAS::opencl_global);
165+
PointeeType =
166+
Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
167+
QualType PointerType = Context.getPointerType(PointeeType);
168+
ActualArgType =
169+
Context.getQualifiedType(PointerType.getUnqualifiedType(), Quals);
170+
}
171+
}
172+
}
173+
DeclContext *DC = Context.getTranslationUnitDecl();
174+
175+
IdentifierInfo *VarName = 0;
176+
SmallString<8> Str;
177+
llvm::raw_svector_ostream OS(Str);
178+
OS << "_arg_" << V->getIdentifier()->getName();
179+
VarName = &Context.Idents.get(OS.str());
180+
181+
auto NewVarDecl = VarDecl::Create(
182+
Context, DC, SourceLocation(), SourceLocation(), VarName, ActualArgType,
183+
Context.getTrivialTypeSourceInfo(ActualArgType), SC_None);
184+
ArgTys.push_back(ActualArgType);
185+
NewArgDecls.push_back(NewVarDecl);
186+
}
187+
}
188+
189+
void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
190+
// TODO: Case when kernel is functor
191+
LambdaExpr *LE = getBodyAsLambda(e);
192+
if (LE) {
193+
194+
llvm::SmallVector<DeclaratorDecl *, 16> ArgDecls;
195+
196+
for (const auto &V : LE->captures()) {
197+
ArgDecls.push_back(V.getCapturedVar());
198+
}
199+
200+
llvm::SmallVector<QualType, 16> ArgTys;
201+
llvm::SmallVector<DeclaratorDecl *, 16> NewArgDecls;
202+
BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys);
203+
204+
// Get Name for our kernel.
205+
FunctionDecl *FuncDecl = e->getMethodDecl();
206+
const TemplateArgumentList *TemplateArgs =
207+
FuncDecl->getTemplateSpecializationArgs();
208+
QualType KernelNameType = TemplateArgs->get(0).getAsType();
209+
std::string Name = KernelNameType.getBaseTypeIdentifier()->getName().str();
210+
211+
if (const auto *RecordDecl = KernelNameType->getAsCXXRecordDecl()) {
212+
const auto *TemplateDecl =
213+
dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl);
214+
if (TemplateDecl) {
215+
QualType ParamType = TemplateDecl->getTemplateArgs()[0].getAsType();
216+
Name += "_" + ParamType.getAsString() + "_";
217+
}
218+
}
219+
220+
FunctionDecl *SYCLKernel =
221+
CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls);
222+
223+
CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel);
224+
SYCLKernel->setBody(SYCLKernelBody);
225+
226+
AddSyclKernel(SYCLKernel);
227+
}
228+
}
229+

0 commit comments

Comments
 (0)