@@ -123,6 +123,10 @@ class Util {
123
123
// / specialization id class.
124
124
static bool isSyclSpecIdType (QualType Ty);
125
125
126
+ // / Checks whether given clang type is a full specialization of the SYCL
127
+ // / device_global class.
128
+ static bool isSyclDeviceGlobalType (QualType Ty);
129
+
126
130
// / Checks whether given clang type is a full specialization of the SYCL
127
131
// / kernel_handler class.
128
132
static bool isSyclKernelHandlerType (QualType Ty);
@@ -4692,7 +4696,23 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
4692
4696
O << " namespace sycl {\n " ;
4693
4697
O << " namespace detail {\n " ;
4694
4698
4695
- O << " \n " ;
4699
+ // Generate declaration of variable of type __sycl_device_global_registration
4700
+ // whose sole purpose is to run its constructor before the application's
4701
+ // main() function.
4702
+
4703
+ if (S.getSyclIntegrationFooter ().isDeviceGlobalsEmitted ()) {
4704
+ O << " namespace {\n " ;
4705
+
4706
+ O << " class __sycl_device_global_registration {\n " ;
4707
+ O << " public:\n " ;
4708
+ O << " __sycl_device_global_registration() noexcept;\n " ;
4709
+ O << " };\n " ;
4710
+ O << " __sycl_device_global_registration __sycl_device_global_registrar;\n " ;
4711
+
4712
+ O << " } // namespace\n " ;
4713
+
4714
+ O << " \n " ;
4715
+ }
4696
4716
4697
4717
O << " // names of all kernels defined in the corresponding source\n " ;
4698
4718
O << " static constexpr\n " ;
@@ -4874,9 +4894,9 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
4874
4894
// template instantiations as a VarDecl.
4875
4895
if (isa<VarTemplatePartialSpecializationDecl>(VD))
4876
4896
return ;
4877
- // Step 1: ensure that this is of the correct type-spec-constant template
4878
- // specialization).
4879
- if ( !Util::isSyclSpecIdType (VD->getType ())) {
4897
+ // Step 1: ensure that this is of the correct type template specialization.
4898
+ if (! Util::isSyclSpecIdType (VD-> getType ()) &&
4899
+ !Util::isSyclDeviceGlobalType (VD->getType ())) {
4880
4900
// Handle the case where this could be a deduced type, such as a deduction
4881
4901
// guide. We have to do this here since this function, unlike most of the
4882
4902
// rest of this file, is called during Sema instead of after it. We will
@@ -4892,8 +4912,8 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
4892
4912
// let an error happen during host compilation.
4893
4913
if (!VD->hasGlobalStorage () || VD->isLocalVarDeclOrParm ())
4894
4914
return ;
4895
- // Step 3: Add to SpecConstants collection.
4896
- SpecConstants .push_back (VD);
4915
+ // Step 3: Add to collection.
4916
+ GlobalVars .push_back (VD);
4897
4917
}
4898
4918
4899
4919
// Post-compile integration header support.
@@ -4967,29 +4987,28 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
4967
4987
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
4968
4988
}
4969
4989
4970
- static std::string EmitSpecIdShim (raw_ostream &OS, unsigned &ShimCounter,
4971
- const std::string &LastShim,
4972
- const NamespaceDecl *AnonNS) {
4990
+ static std::string EmitShim (raw_ostream &OS, unsigned &ShimCounter,
4991
+ const std::string &LastShim,
4992
+ const NamespaceDecl *AnonNS) {
4973
4993
std::string NewShimName =
4974
- " __sycl_detail::__spec_id_shim_ " + std::to_string (ShimCounter) + " ()" ;
4994
+ " __sycl_detail::__shim_ " + std::to_string (ShimCounter) + " ()" ;
4975
4995
// Print opening-namespace
4976
4996
PrintNamespaces (OS, Decl::castToDeclContext (AnonNS));
4977
4997
OS << " namespace __sycl_detail {\n " ;
4978
- OS << " static constexpr decltype(" << LastShim << " ) &__spec_id_shim_ "
4979
- << ShimCounter << " () {\n " ;
4998
+ OS << " static constexpr decltype(" << LastShim << " ) &__shim_ " << ShimCounter
4999
+ << " () {\n " ;
4980
5000
OS << " return " << LastShim << " ;\n " ;
4981
5001
OS << " }\n " ;
4982
- OS << " } // namespace __sycl_detail \n " ;
5002
+ OS << " } // namespace __sycl_detail\n " ;
4983
5003
PrintNSClosingBraces (OS, Decl::castToDeclContext (AnonNS));
4984
5004
4985
5005
++ShimCounter;
4986
5006
return NewShimName;
4987
5007
}
4988
5008
4989
5009
// Emit the list of shims required for a DeclContext, calls itself recursively.
4990
- static void EmitSpecIdShims (raw_ostream &OS, unsigned &ShimCounter,
4991
- const DeclContext *DC,
4992
- std::string &NameForLastShim) {
5010
+ static void EmitShims (raw_ostream &OS, unsigned &ShimCounter,
5011
+ const DeclContext *DC, std::string &NameForLastShim) {
4993
5012
if (DC->isTranslationUnit ()) {
4994
5013
NameForLastShim = " ::" + NameForLastShim;
4995
5014
return ;
@@ -5003,7 +5022,7 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
5003
5022
} else if (const auto *ND = dyn_cast<NamespaceDecl>(CurDecl)) {
5004
5023
if (ND->isAnonymousNamespace ()) {
5005
5024
// Print current shim, reset 'name for last shim'.
5006
- NameForLastShim = EmitSpecIdShim (OS, ShimCounter, NameForLastShim, ND);
5025
+ NameForLastShim = EmitShim (OS, ShimCounter, NameForLastShim, ND);
5007
5026
} else {
5008
5027
NameForLastShim = ND->getNameAsString () + " ::" + NameForLastShim;
5009
5028
}
@@ -5017,22 +5036,22 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
5017
5036
" Unhandled decl type" );
5018
5037
}
5019
5038
5020
- EmitSpecIdShims (OS, ShimCounter, CurDecl->getDeclContext (), NameForLastShim);
5039
+ EmitShims (OS, ShimCounter, CurDecl->getDeclContext (), NameForLastShim);
5021
5040
}
5022
5041
5023
5042
// Emit the list of shims required for a variable declaration.
5024
5043
// Returns a string containing the FQN of the 'top most' shim, including its
5025
5044
// function call parameters.
5026
- static std::string EmitSpecIdShims (raw_ostream &OS, unsigned &ShimCounter,
5027
- PrintingPolicy &Policy, const VarDecl *VD) {
5045
+ static std::string EmitShims (raw_ostream &OS, unsigned &ShimCounter,
5046
+ PrintingPolicy &Policy, const VarDecl *VD) {
5028
5047
if (!VD->isInAnonymousNamespace ())
5029
5048
return " " ;
5030
5049
std::string RelativeName;
5031
5050
llvm::raw_string_ostream stream (RelativeName);
5032
5051
VD->getNameForDiagnostic (stream, Policy, false );
5033
5052
stream.flush ();
5034
5053
5035
- EmitSpecIdShims (OS, ShimCounter, VD->getDeclContext (), RelativeName);
5054
+ EmitShims (OS, ShimCounter, VD->getDeclContext (), RelativeName);
5036
5055
return RelativeName;
5037
5056
}
5038
5057
@@ -5042,58 +5061,90 @@ bool SYCLIntegrationFooter::emit(raw_ostream &OS) {
5042
5061
Policy.SuppressTypedefs = true ;
5043
5062
Policy.SuppressUnwrittenScope = true ;
5044
5063
5045
- llvm::SmallSet<const VarDecl *, 8 > VisitedSpecConstants ;
5064
+ llvm::SmallSet<const VarDecl *, 8 > Visited ;
5046
5065
bool EmittedFirstSpecConstant = false ;
5047
5066
5048
5067
// Used to uniquely name the 'shim's as we generate the names in each
5049
5068
// anonymous namespace.
5050
5069
unsigned ShimCounter = 0 ;
5051
- for (const VarDecl *VD : SpecConstants) {
5070
+
5071
+ std::string DeviceGlobalsBuf;
5072
+ llvm::raw_string_ostream DeviceGlobOS (DeviceGlobalsBuf);
5073
+ for (const VarDecl *VD : GlobalVars) {
5052
5074
VD = VD->getCanonicalDecl ();
5053
5075
5054
- // Skip if this isn't a SpecIdType. This can happen if it was a deduced
5055
- // type.
5056
- if (!Util::isSyclSpecIdType (VD->getType ()))
5076
+ // Skip if this isn't a SpecIdType or DeviceGlobal. This can happen if it
5077
+ // was a deduced type.
5078
+ if (!Util::isSyclSpecIdType (VD->getType ()) &&
5079
+ !Util::isSyclDeviceGlobalType (VD->getType ()))
5057
5080
continue ;
5058
5081
5059
5082
// Skip if we've already visited this.
5060
- if (llvm::find (VisitedSpecConstants , VD) != VisitedSpecConstants .end ())
5083
+ if (llvm::find (Visited , VD) != Visited .end ())
5061
5084
continue ;
5062
5085
5063
- // We only want to emit the #includes if we have a spec-constant that needs
5086
+ // We only want to emit the #includes if we have a variable that needs
5064
5087
// them, so emit this one on the first time through the loop.
5065
- if (!EmittedFirstSpecConstant)
5088
+ if (!EmittedFirstSpecConstant && !DeviceGlobalsEmitted )
5066
5089
OS << " #include <CL/sycl/detail/defines_elementary.hpp>\n " ;
5067
- EmittedFirstSpecConstant = true ;
5068
-
5069
- VisitedSpecConstants.insert (VD);
5070
- std::string TopShim = EmitSpecIdShims (OS, ShimCounter, Policy, VD);
5071
- OS << " __SYCL_INLINE_NAMESPACE(cl) {\n " ;
5072
- OS << " namespace sycl {\n " ;
5073
- OS << " namespace detail {\n " ;
5074
- OS << " template<>\n " ;
5075
- OS << " inline const char *get_spec_constant_symbolic_ID_impl<" ;
5076
-
5077
- if (VD->isInAnonymousNamespace ()) {
5078
- OS << TopShim;
5090
+
5091
+ Visited.insert (VD);
5092
+ std::string TopShim = EmitShims (OS, ShimCounter, Policy, VD);
5093
+ if (Util::isSyclDeviceGlobalType (VD->getType ())) {
5094
+ DeviceGlobalsEmitted = true ;
5095
+ DeviceGlobOS << " device_global_map::add(" ;
5096
+ DeviceGlobOS << " (void *)&" ;
5097
+ if (VD->isInAnonymousNamespace ()) {
5098
+ DeviceGlobOS << TopShim;
5099
+ } else {
5100
+ DeviceGlobOS << " ::" ;
5101
+ VD->getNameForDiagnostic (DeviceGlobOS, Policy, true );
5102
+ }
5103
+ DeviceGlobOS << " , \" " ;
5104
+ DeviceGlobOS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (),
5105
+ VD);
5106
+ DeviceGlobOS << " \" );\n " ;
5079
5107
} else {
5080
- OS << " ::" ;
5081
- VD->getNameForDiagnostic (OS, Policy, true );
5082
- }
5108
+ EmittedFirstSpecConstant = true ;
5109
+ OS << " __SYCL_INLINE_NAMESPACE(cl) {\n " ;
5110
+ OS << " namespace sycl {\n " ;
5111
+ OS << " namespace detail {\n " ;
5112
+ OS << " template<>\n " ;
5113
+ OS << " inline const char *get_spec_constant_symbolic_ID_impl<" ;
5114
+
5115
+ if (VD->isInAnonymousNamespace ()) {
5116
+ OS << TopShim;
5117
+ } else {
5118
+ OS << " ::" ;
5119
+ VD->getNameForDiagnostic (OS, Policy, true );
5120
+ }
5083
5121
5084
- OS << " >() {\n " ;
5085
- OS << " return \" " ;
5086
- OS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (), VD);
5087
- OS << " \" ;\n " ;
5088
- OS << " }\n " ;
5089
- OS << " } // namespace detail\n " ;
5090
- OS << " } // namespace sycl\n " ;
5091
- OS << " } // __SYCL_INLINE_NAMESPACE(cl)\n " ;
5122
+ OS << " >() {\n " ;
5123
+ OS << " return \" " ;
5124
+ OS << SYCLUniqueStableIdExpr::ComputeName (S.getASTContext (), VD);
5125
+ OS << " \" ;\n " ;
5126
+ OS << " }\n " ;
5127
+ OS << " } // namespace detail\n " ;
5128
+ OS << " } // namespace sycl\n " ;
5129
+ OS << " } // __SYCL_INLINE_NAMESPACE(cl)\n " ;
5130
+ }
5092
5131
}
5093
5132
5094
5133
if (EmittedFirstSpecConstant)
5095
5134
OS << " #include <CL/sycl/detail/spec_const_integration.hpp>\n " ;
5096
5135
5136
+ if (DeviceGlobalsEmitted) {
5137
+ OS << " #include <CL/sycl/detail/device_global_map.hpp>\n " ;
5138
+ DeviceGlobOS.flush ();
5139
+ OS << " namespace sycl::detail {\n " ;
5140
+ OS << " namespace {\n " ;
5141
+ OS << " __sycl_device_global_registration::__sycl_device_global_"
5142
+ " registration() noexcept {\n " ;
5143
+ OS << DeviceGlobalsBuf;
5144
+ OS << " }\n " ;
5145
+ OS << " } // namespace (unnamed)\n " ;
5146
+ OS << " } // namespace sycl::detail\n " ;
5147
+ }
5097
5148
return true ;
5098
5149
}
5099
5150
@@ -5138,6 +5189,18 @@ bool Util::isSyclSpecIdType(QualType Ty) {
5138
5189
return matchQualifiedTypeName (Ty, Scopes);
5139
5190
}
5140
5191
5192
+ bool Util::isSyclDeviceGlobalType (QualType Ty) {
5193
+ const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl ();
5194
+ if (!RecTy)
5195
+ return false ;
5196
+ if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RecTy)) {
5197
+ ClassTemplateDecl *Template = CTSD->getSpecializedTemplate ();
5198
+ if (CXXRecordDecl *RD = Template->getTemplatedDecl ())
5199
+ return RD->hasAttr <SYCLDeviceGlobalAttr>();
5200
+ }
5201
+ return RecTy->hasAttr <SYCLDeviceGlobalAttr>();
5202
+ }
5203
+
5141
5204
bool Util::isSyclKernelHandlerType (QualType Ty) {
5142
5205
std::array<DeclContextDesc, 3 > Scopes = {
5143
5206
Util::MakeDeclContextDesc (Decl::Kind::Namespace, " cl" ),
0 commit comments