|
| 1 | +//===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===// |
| 2 | +// |
| 3 | +// The LLVM Compiler Infrastructure |
| 4 | +// |
| 5 | +// This file is distributed under the University of Illinois Open Source |
| 6 | +// License. See LICENSE.TXT for details. |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | +// This implements Semantic Analysis for SYCL constructs. |
| 10 | +//===----------------------------------------------------------------------===// |
| 11 | + |
| 12 | +#include "clang/AST/AST.h" |
| 13 | +#include "clang/Sema/Sema.h" |
| 14 | +#include "llvm/ADT/SmallVector.h" |
| 15 | + |
| 16 | +using namespace clang; |
| 17 | + |
| 18 | +LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) { |
| 19 | + auto LastArg = e->getArg(e->getNumArgs() - 1); |
| 20 | + return dyn_cast<LambdaExpr>(LastArg); |
| 21 | +} |
| 22 | + |
| 23 | +FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name, |
| 24 | + ArrayRef<QualType> ArgTys, |
| 25 | + ArrayRef<DeclaratorDecl *> ArgDecls) { |
| 26 | + |
| 27 | + DeclContext *DC = Context.getTranslationUnitDecl(); |
| 28 | + FunctionProtoType::ExtProtoInfo Info; |
| 29 | + QualType RetTy = Context.VoidTy; |
| 30 | + QualType FuncTy = Context.getFunctionType(RetTy, ArgTys, Info); |
| 31 | + DeclarationName DN = DeclarationName(&Context.Idents.get(Name)); |
| 32 | + FunctionDecl *Result = FunctionDecl::Create( |
| 33 | + Context, DC, SourceLocation(), SourceLocation(), DN, FuncTy, |
| 34 | + Context.getTrivialTypeSourceInfo(RetTy), SC_None); |
| 35 | + llvm::SmallVector<ParmVarDecl *, 16> Params; |
| 36 | + int i = 0; |
| 37 | + for (auto ArgTy : ArgTys) { |
| 38 | + auto P = |
| 39 | + ParmVarDecl::Create(Context, Result, SourceLocation(), SourceLocation(), |
| 40 | + ArgDecls[i]->getIdentifier(), ArgTy, |
| 41 | + ArgDecls[i]->getTypeSourceInfo(), SC_None, 0); |
| 42 | + P->setScopeInfo(0, i++); |
| 43 | + P->setIsUsed(); |
| 44 | + Params.push_back(P); |
| 45 | + } |
| 46 | + Result->setParams(Params); |
| 47 | + // TODO: Add SYCL specific attribute for kernel and all functions called |
| 48 | + // by kernel. |
| 49 | + Result->addAttr(OpenCLKernelAttr::CreateImplicit(Context)); |
| 50 | + Result->addAttr(AsmLabelAttr::CreateImplicit(Context, Name)); |
| 51 | + return Result; |
| 52 | +} |
| 53 | + |
| 54 | +CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e, |
| 55 | + DeclContext *DC) { |
| 56 | + |
| 57 | + llvm::SmallVector<Stmt *, 16> BodyStmts; |
| 58 | + |
| 59 | + // TODO: case when kernel is functor |
| 60 | + // TODO: possible refactoring when functor case will be completed |
| 61 | + LambdaExpr *LE = getBodyAsLambda(e); |
| 62 | + if (LE) { |
| 63 | + // Create Lambda object |
| 64 | + CXXRecordDecl *LC = LE->getLambdaClass(); |
| 65 | + auto Lambda_VD = VarDecl::Create( |
| 66 | + S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(), |
| 67 | + QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None); |
| 68 | + Stmt *DS = new (S.Context) |
| 69 | + DeclStmt(DeclGroupRef(Lambda_VD), SourceLocation(), SourceLocation()); |
| 70 | + BodyStmts.push_back(DS); |
| 71 | + auto Lambda_DRE = DeclRefExpr::Create( |
| 72 | + S.Context, NestedNameSpecifierLoc(), SourceLocation(), Lambda_VD, false, |
| 73 | + DeclarationNameInfo(), QualType(LC->getTypeForDecl(), 0), VK_LValue); |
| 74 | + |
| 75 | + // Init Lambda fields |
| 76 | + llvm::SmallVector<Expr *, 16> InitCaptures; |
| 77 | + |
| 78 | + auto TargetFunc = dyn_cast<FunctionDecl>(DC); |
| 79 | + auto TargetFuncParam = |
| 80 | + TargetFunc->param_begin(); // Iterator to ParamVarDecl (VarDecl) |
| 81 | + for (auto CaptureField : LE->captures()) { |
| 82 | + VarDecl *CapturedVar = |
| 83 | + CaptureField |
| 84 | + .getCapturedVar(); // accessor, need to do setInit for this |
| 85 | + QualType ParamType = (*TargetFuncParam)->getOriginalType(); |
| 86 | + auto DRE = DeclRefExpr::Create( |
| 87 | + S.Context, NestedNameSpecifierLoc(), SourceLocation(), |
| 88 | + *TargetFuncParam, false, DeclarationNameInfo(), ParamType, VK_LValue); |
| 89 | + |
| 90 | + Expr *Res = ImplicitCastExpr::Create( |
| 91 | + S.Context, ParamType, CK_LValueToRValue, DRE, nullptr, VK_RValue); |
| 92 | + |
| 93 | + Expr *InitCapture = new (S.Context) InitListExpr( |
| 94 | + S.Context, SourceLocation(), /*initExprs*/ Res, SourceLocation()); |
| 95 | + CapturedVar->setInit(InitCapture); |
| 96 | + InitCapture->setType(CapturedVar->getType()); |
| 97 | + InitCaptures.push_back(InitCapture); |
| 98 | + TargetFuncParam++; |
| 99 | + } |
| 100 | + |
| 101 | + Expr *InitLambdaCaptures = new (S.Context) |
| 102 | + InitListExpr(S.Context, SourceLocation(), /*initExprs*/ InitCaptures, |
| 103 | + SourceLocation()); |
| 104 | + InitLambdaCaptures->setType(Lambda_VD->getType()); |
| 105 | + Lambda_VD->setInit(InitLambdaCaptures); |
| 106 | + |
| 107 | + // Create Lambda operator () call |
| 108 | + FunctionDecl *LO = LE->getCallOperator(); |
| 109 | + ArrayRef<ParmVarDecl *> Args = LO->parameters(); |
| 110 | + llvm::SmallVector<Expr *, 16> ParamStmts(1); |
| 111 | + ParamStmts[0] = dyn_cast<Expr>(Lambda_DRE); |
| 112 | + |
| 113 | + // Collect arguments for () operator |
| 114 | + for (auto Arg : Args) { |
| 115 | + QualType ArgType = Arg->getOriginalType(); |
| 116 | + // Declare variable for parameter and pass it to call |
| 117 | + auto param_VD = |
| 118 | + VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(), |
| 119 | + Arg->getIdentifier(), ArgType, |
| 120 | + S.Context.getTrivialTypeSourceInfo(ArgType), SC_None); |
| 121 | + Stmt *param_DS = new (S.Context) |
| 122 | + DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation()); |
| 123 | + BodyStmts.push_back(param_DS); |
| 124 | + auto DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(), |
| 125 | + SourceLocation(), param_VD, false, |
| 126 | + DeclarationNameInfo(), ArgType, VK_LValue); |
| 127 | + ParamStmts.push_back(DRE); |
| 128 | + } |
| 129 | + |
| 130 | + // Create ref for call operator |
| 131 | + DeclRefExpr *DRE = new (S.Context) |
| 132 | + DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue, |
| 133 | + SourceLocation()); |
| 134 | + QualType ResultTy = LO->getReturnType(); |
| 135 | + ExprValueKind VK = Expr::getValueKindForType(ResultTy); |
| 136 | + ResultTy = ResultTy.getNonLValueExprType(S.Context); |
| 137 | + |
| 138 | + CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create( |
| 139 | + S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(), |
| 140 | + FPOptions(), clang::CallExpr::ADLCallKind::NotADL ); |
| 141 | + BodyStmts.push_back(TheCall); |
| 142 | + } |
| 143 | + return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(), |
| 144 | + SourceLocation()); |
| 145 | +} |
| 146 | + |
| 147 | +void BuildArgTys(ASTContext &Context, |
| 148 | + llvm::SmallVector<DeclaratorDecl *, 16> &ArgDecls, |
| 149 | + llvm::SmallVector<DeclaratorDecl *, 16> &NewArgDecls, |
| 150 | + llvm::SmallVector<QualType, 16> &ArgTys) { |
| 151 | + for (auto V : ArgDecls) { |
| 152 | + QualType ArgTy = V->getType(); |
| 153 | + QualType ActualArgType = ArgTy; |
| 154 | + StringRef Name = ArgTy.getBaseTypeIdentifier()->getName(); |
| 155 | + // TODO: harden this check with additional validation that this class is |
| 156 | + // declared in cl::sycl namespace |
| 157 | + if (std::string(Name) == "accessor") { |
| 158 | + if (const auto *RecordDecl = ArgTy->getAsCXXRecordDecl()) { |
| 159 | + const auto *TemplateDecl = |
| 160 | + dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl); |
| 161 | + if (TemplateDecl) { |
| 162 | + QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); |
| 163 | + Qualifiers Quals = PointeeType.getQualifiers(); |
| 164 | + Quals.setAddressSpace(LangAS::opencl_global); |
| 165 | + PointeeType = |
| 166 | + Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals); |
| 167 | + QualType PointerType = Context.getPointerType(PointeeType); |
| 168 | + ActualArgType = |
| 169 | + Context.getQualifiedType(PointerType.getUnqualifiedType(), Quals); |
| 170 | + } |
| 171 | + } |
| 172 | + } |
| 173 | + DeclContext *DC = Context.getTranslationUnitDecl(); |
| 174 | + |
| 175 | + IdentifierInfo *VarName = 0; |
| 176 | + SmallString<8> Str; |
| 177 | + llvm::raw_svector_ostream OS(Str); |
| 178 | + OS << "_arg_" << V->getIdentifier()->getName(); |
| 179 | + VarName = &Context.Idents.get(OS.str()); |
| 180 | + |
| 181 | + auto NewVarDecl = VarDecl::Create( |
| 182 | + Context, DC, SourceLocation(), SourceLocation(), VarName, ActualArgType, |
| 183 | + Context.getTrivialTypeSourceInfo(ActualArgType), SC_None); |
| 184 | + ArgTys.push_back(ActualArgType); |
| 185 | + NewArgDecls.push_back(NewVarDecl); |
| 186 | + } |
| 187 | +} |
| 188 | + |
| 189 | +void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) { |
| 190 | + // TODO: Case when kernel is functor |
| 191 | + LambdaExpr *LE = getBodyAsLambda(e); |
| 192 | + if (LE) { |
| 193 | + |
| 194 | + llvm::SmallVector<DeclaratorDecl *, 16> ArgDecls; |
| 195 | + |
| 196 | + for (const auto &V : LE->captures()) { |
| 197 | + ArgDecls.push_back(V.getCapturedVar()); |
| 198 | + } |
| 199 | + |
| 200 | + llvm::SmallVector<QualType, 16> ArgTys; |
| 201 | + llvm::SmallVector<DeclaratorDecl *, 16> NewArgDecls; |
| 202 | + BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys); |
| 203 | + |
| 204 | + // Get Name for our kernel. |
| 205 | + FunctionDecl *FuncDecl = e->getMethodDecl(); |
| 206 | + const TemplateArgumentList *TemplateArgs = |
| 207 | + FuncDecl->getTemplateSpecializationArgs(); |
| 208 | + QualType KernelNameType = TemplateArgs->get(0).getAsType(); |
| 209 | + std::string Name = KernelNameType.getBaseTypeIdentifier()->getName().str(); |
| 210 | + |
| 211 | + if (const auto *RecordDecl = KernelNameType->getAsCXXRecordDecl()) { |
| 212 | + const auto *TemplateDecl = |
| 213 | + dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl); |
| 214 | + if (TemplateDecl) { |
| 215 | + QualType ParamType = TemplateDecl->getTemplateArgs()[0].getAsType(); |
| 216 | + Name += "_" + ParamType.getAsString() + "_"; |
| 217 | + } |
| 218 | + } |
| 219 | + |
| 220 | + FunctionDecl *SYCLKernel = |
| 221 | + CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls); |
| 222 | + |
| 223 | + CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel); |
| 224 | + SYCLKernel->setBody(SYCLKernelBody); |
| 225 | + |
| 226 | + AddSyclKernel(SYCLKernel); |
| 227 | + } |
| 228 | +} |
| 229 | + |
0 commit comments