@@ -49,27 +49,113 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
4949 MarkDeviceFunction (Sema &S)
5050 : RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
5151 bool VisitCallExpr (CallExpr *e) {
52+ for (const auto &Arg : e->arguments ())
53+ CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
54+
5255 if (FunctionDecl *Callee = e->getDirectCallee ()) {
5356 // Remember that all SYCL kernel functions have deferred
5457 // instantiation as template functions. It means that
5558 // all functions used by kernel have already been parsed and have
5659 // definitions.
60+
61+ CheckTypeForVirtual (Callee->getReturnType (), Callee->getSourceRange ());
62+
5763 if (FunctionDecl *Def = Callee->getDefinition ()) {
5864 if (!Def->hasAttr <SYCLDeviceAttr>()) {
5965 Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
6066 this ->TraverseStmt (Def->getBody ());
61- // But because parser works with top level declarations and CodeGen
62- // already saw and ignored our function without device attribute we
63- // need to add this function into SYCL kernels array to show it
64- // this function again.
6567 SemaRef.AddSyclKernel (Def);
6668 }
6769 }
6870 }
6971 return true ;
7072 }
7173
74+ bool VisitCXXConstructExpr (CXXConstructExpr *E) {
75+ for (const auto &Arg : E->arguments ())
76+ CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
77+
78+ CXXConstructorDecl *Ctor = E->getConstructor ();
79+
80+ if (FunctionDecl *Def = Ctor->getDefinition ()) {
81+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
82+ this ->TraverseStmt (Def->getBody ());
83+ SemaRef.AddSyclKernel (Def);
84+ }
85+
86+ const auto *ConstructedType = Ctor->getParent ();
87+ if (ConstructedType->hasUserDeclaredDestructor ()) {
88+ CXXDestructorDecl *Dtor = ConstructedType->getDestructor ();
89+
90+ if (FunctionDecl *Def = Dtor->getDefinition ()) {
91+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
92+ this ->TraverseStmt (Def->getBody ());
93+ SemaRef.AddSyclKernel (Def);
94+ }
95+ }
96+ return true ;
97+ }
98+
99+ bool VisitTypedefNameDecl (TypedefNameDecl *TD) {
100+ CheckTypeForVirtual (TD->getUnderlyingType (), TD->getLocation ());
101+ return true ;
102+ }
103+
104+ bool VisitRecordDecl (RecordDecl *RD) {
105+ CheckTypeForVirtual (QualType{RD->getTypeForDecl (), 0 }, RD->getLocation ());
106+ return true ;
107+ }
108+
109+ bool VisitParmVarDecl (VarDecl *VD) {
110+ CheckTypeForVirtual (VD->getType (), VD->getLocation ());
111+ return true ;
112+ }
113+
114+ bool VisitVarDecl (VarDecl *VD) {
115+ CheckTypeForVirtual (VD->getType (), VD->getLocation ());
116+ return true ;
117+ }
118+
119+ bool VisitDeclRefExpr (DeclRefExpr *E) {
120+ CheckTypeForVirtual (E->getType (), E->getSourceRange ());
121+ return true ;
122+ }
123+
72124private:
125+ bool CheckTypeForVirtual (QualType Ty, SourceRange Loc) {
126+ while (Ty->isAnyPointerType () || Ty->isArrayType ())
127+ Ty = QualType{Ty->getPointeeOrArrayElementType (), 0 };
128+
129+ if (const auto *CRD = Ty->getAsCXXRecordDecl ()) {
130+ if (CRD->isPolymorphic ()) {
131+ SemaRef.Diag (CRD->getLocation (), diag::err_sycl_virtual_types);
132+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
133+ return false ;
134+ }
135+
136+ for (const auto &Field : CRD->fields ()) {
137+ if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
138+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
139+ return false ;
140+ }
141+ }
142+ } else if (const auto *RD = Ty->getAsRecordDecl ()) {
143+ for (const auto &Field : RD->fields ()) {
144+ if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
145+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
146+ return false ;
147+ }
148+ }
149+ } else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
150+ for (const auto &ParamTy : FPTy->param_types ())
151+ if (!CheckTypeForVirtual (ParamTy, Loc))
152+ return false ;
153+ return CheckTypeForVirtual (FPTy->getReturnType (), Loc);
154+ } else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
155+ return CheckTypeForVirtual (FTy->getReturnType (), Loc);
156+ }
157+ return true ;
158+ }
73159 Sema &SemaRef;
74160};
75161
0 commit comments