1414#include " clang/AST/QualTypeNames.h"
1515#include " clang/AST/RecordLayout.h"
1616#include " clang/AST/RecursiveASTVisitor.h"
17+ #include " clang/AST/TemplateArgumentVisitor.h"
18+ #include " clang/AST/TypeVisitor.h"
1719#include " clang/Analysis/CallGraph.h"
1820#include " clang/Basic/Attributes.h"
1921#include " clang/Basic/Builtins.h"
@@ -2473,9 +2475,111 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
24732475
24742476} // namespace
24752477
2478+ class SYCLTypeVisitor : public TypeVisitor <SYCLTypeVisitor>,
2479+ public ConstTemplateArgumentVisitor<SYCLTypeVisitor> {
2480+ Sema &S;
2481+ SourceLocation Loc;
2482+ using InnerTypeVisitor = TypeVisitor<SYCLTypeVisitor>;
2483+ using InnerTAVisitor = ConstTemplateArgumentVisitor<SYCLTypeVisitor>;
2484+
2485+ public:
2486+ SYCLTypeVisitor (Sema &S, SourceLocation Loc) : S(S), Loc(Loc) {}
2487+
2488+ void Visit (QualType T) {
2489+ if (T.isNull ())
2490+ return ;
2491+ const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
2492+ if (!RD)
2493+ return ;
2494+ // If KernelNameType has template args visit each template arg via
2495+ // ConstTemplateArgumentVisitor
2496+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
2497+ const TemplateArgumentList &Args = TSD->getTemplateArgs ();
2498+ for (unsigned I = 0 ; I < Args.size (); I++) {
2499+ const TemplateArgument &TemplateArg = Args[I];
2500+ Visit (TemplateArg);
2501+ }
2502+ } else {
2503+ InnerTypeVisitor::Visit (T.getTypePtr ());
2504+ }
2505+ }
2506+
2507+ void Visit (const TemplateArgument &TA) { InnerTAVisitor::Visit (TA); }
2508+
2509+ void VisitEnumType (const EnumType *T) {
2510+ const EnumDecl *ED = T->getDecl ();
2511+ if (!ED->isScoped () && !ED->isFixed ()) {
2512+ S.Diag (Loc, diag::err_sycl_kernel_incorrectly_named) << 2 ;
2513+ S.Diag (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2514+ << ED;
2515+ return ;
2516+ }
2517+ }
2518+
2519+ void VisitRecordType (const RecordType *T) {
2520+ return VisitTagDecl (T->getDecl ());
2521+ }
2522+
2523+ void VisitTagDecl (const TagDecl *Tag) {
2524+ bool UnnamedLambdaEnabled =
2525+ S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
2526+ if (Tag && !UnnamedLambdaEnabled) {
2527+ const bool KernelNameIsMissing = Tag->getName ().empty ();
2528+ if (KernelNameIsMissing) {
2529+ S.Diag (Loc, diag::err_sycl_kernel_incorrectly_named)
2530+ << /* kernel name is missing */ 0 ;
2531+ return ;
2532+ } else {
2533+ if (Tag->isCompleteDefinition ())
2534+ S.Diag (Loc, diag::err_sycl_kernel_incorrectly_named)
2535+ << /* kernel name is not globally-visible */ 1 ;
2536+ else
2537+ S.Diag (Loc, diag::warn_sycl_implicit_decl);
2538+
2539+ S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
2540+ << Tag->getName ();
2541+ return ;
2542+ }
2543+ }
2544+ }
2545+
2546+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
2547+ QualType T = TA.getAsType ();
2548+ if (const auto *ET = T->getAs <EnumType>()) {
2549+ VisitEnumType (ET);
2550+ } else {
2551+ Visit (T);
2552+ }
2553+ return ;
2554+ }
2555+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
2556+ QualType T = TA.getIntegralType ();
2557+ if (const EnumType *ET = T->getAs <EnumType>()) {
2558+ VisitEnumType (ET);
2559+ }
2560+ return ;
2561+ }
2562+
2563+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
2564+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
2565+ TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
2566+ for (NamedDecl *P : *TemplateParams) {
2567+ if (NonTypeTemplateParmDecl *TemplateParam =
2568+ dyn_cast<NonTypeTemplateParmDecl>(P)) {
2569+ QualType T = TemplateParam->getType ();
2570+ if (const EnumType *ET = T->getAs <EnumType>()) {
2571+ VisitEnumType (ET);
2572+ }
2573+ }
2574+ }
2575+ }
2576+ };
2577+
24762578void Sema::CheckSYCLKernelCall (FunctionDecl *KernelFunc, SourceRange CallLoc,
24772579 ArrayRef<const Expr *> Args) {
24782580 const CXXRecordDecl *KernelObj = getKernelObjectType (KernelFunc);
2581+ QualType KernelNameType =
2582+ calculateKernelNameType (getASTContext (), KernelFunc);
24792583 if (!KernelObj) {
24802584 Diag (Args[0 ]->getExprLoc (), diag::err_sycl_kernel_not_function_object);
24812585 KernelFunc->setInvalidDecl ();
@@ -2511,6 +2615,8 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
25112615 return ;
25122616
25132617 KernelObjVisitor Visitor{*this };
2618+ SYCLTypeVisitor KernelTypeVisitor (*this , Args[0 ]->getExprLoc ());
2619+ (void )KernelTypeVisitor.Visit (KernelNameType);
25142620 DiagnosingSYCLKernel = true ;
25152621 Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
25162622 ArgsSizeChecker);
@@ -2856,18 +2962,6 @@ static void emitWithoutAnonNamespaces(llvm::raw_ostream &OS, StringRef Source) {
28562962 OS << Source;
28572963}
28582964
2859- static bool checkEnumTemplateParameter (const EnumDecl *ED,
2860- DiagnosticsEngine &Diag,
2861- SourceLocation KernelLocation) {
2862- if (!ED->isScoped () && !ED->isFixed ()) {
2863- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named) << 2 ;
2864- Diag.Report (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2865- << ED;
2866- return true ;
2867- }
2868- return false ;
2869- }
2870-
28712965// Emits a forward declaration
28722966void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
28732967 SourceLocation KernelLocation) {
@@ -2880,32 +2974,6 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
28802974 auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
28812975
28822976 if (!NS) {
2883- if (!DC->isTranslationUnit ()) {
2884- const TagDecl *TD = isa<ClassTemplateDecl>(D)
2885- ? cast<ClassTemplateDecl>(D)->getTemplatedDecl ()
2886- : dyn_cast<TagDecl>(D);
2887-
2888- if (TD && !UnnamedLambdaSupport) {
2889- // defined class constituting the kernel name is not globally
2890- // accessible - contradicts the spec
2891- const bool KernelNameIsMissing = TD->getName ().empty ();
2892- if (KernelNameIsMissing) {
2893- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named)
2894- << /* kernel name is missing */ 0 ;
2895- // Don't emit note if kernel name was completely omitted
2896- } else {
2897- if (TD->isCompleteDefinition ())
2898- Diag.Report (KernelLocation,
2899- diag::err_sycl_kernel_incorrectly_named)
2900- << /* kernel name is not globally-visible */ 1 ;
2901- else
2902- Diag.Report (KernelLocation, diag::warn_sycl_implicit_decl);
2903- Diag.Report (D->getSourceRange ().getBegin (),
2904- diag::note_previous_decl)
2905- << TD->getName ();
2906- }
2907- }
2908- }
29092977 break ;
29102978 }
29112979 ++NamespaceCnt;
@@ -3013,7 +3081,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
30133081 // Handle Kernel Name Type templated using enum type and value.
30143082 if (const auto *ET = T->getAs <EnumType>()) {
30153083 const EnumDecl *ED = ET->getDecl ();
3016- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
30173084 emitFwdDecl (O, ED, KernelLocation);
30183085 } else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
30193086 emitForwardClassDecls (O, T, KernelLocation, Printed);
@@ -3073,7 +3140,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
30733140 QualType T = TemplateParam->getType ();
30743141 if (const auto *ET = T->getAs <EnumType>()) {
30753142 const EnumDecl *ED = ET->getDecl ();
3076- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
30773143 emitFwdDecl (O, ED, KernelLocation);
30783144 }
30793145 }
0 commit comments