Skip to content

Commit a9ad3af

Browse files
authored
[SYCL] Emit integration footer and header for device_global variables (#5576)
The implementation is based on the design doc https://github.com/intel/llvm/blob/sycl/sycl/doc/design/DeviceGlobal.md
1 parent dda743a commit a9ad3af

11 files changed

+552
-91
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,15 @@ def SYCLUsesAspects : InheritableAttr {
12611261
let Documentation = [Undocumented];
12621262
}
12631263

1264+
def SYCLDeviceGlobal : InheritableAttr {
1265+
let Spellings = [CXX11<"__sycl_detail__", "device_global">];
1266+
let Subjects = SubjectList<[CXXRecord], ErrorDiag>;
1267+
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
1268+
// Only used internally by the SYCL implementation
1269+
let Documentation = [Undocumented];
1270+
let SimpleHandler = 1;
1271+
}
1272+
12641273
// Marks functions which must not be vectorized via horizontal SIMT widening,
12651274
// e.g. because the function is already vectorized. Used to mark SYCL
12661275
// explicit SIMD kernels and functions.

clang/include/clang/Sema/Sema.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,12 +440,14 @@ class SYCLIntegrationFooter {
440440
SYCLIntegrationFooter(Sema &S) : S(S) {}
441441
bool emit(StringRef MainSrc);
442442
void addVarDecl(const VarDecl *VD);
443+
bool isDeviceGlobalsEmitted() { return DeviceGlobalsEmitted; }
443444

444445
private:
445446
bool emit(raw_ostream &O);
446447
Sema &S;
447-
llvm::SmallVector<const VarDecl *> SpecConstants;
448+
llvm::SmallVector<const VarDecl *> GlobalVars;
448449
void emitSpecIDName(raw_ostream &O, const VarDecl *VD);
450+
bool DeviceGlobalsEmitted = false;
449451
};
450452

451453
/// Tracks expected type during expression parsing, for use in code completion.

clang/lib/Sema/Sema.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,11 +1110,14 @@ void Sema::ActOnEndOfTranslationUnitFragment(TUFragmentKind Kind) {
11101110
// Set the names of the kernels, now that the names have settled down. This
11111111
// needs to happen before we generate the integration headers.
11121112
SetSYCLKernelNames();
1113+
// Make sure that the footer is emitted before header, since only after the
1114+
// footer is emitted is it known that translation unit contains device
1115+
// global variables.
1116+
if (SyclIntFooter != nullptr)
1117+
SyclIntFooter->emit(getLangOpts().SYCLIntFooter);
11131118
// Emit SYCL integration header for current translation unit if needed
11141119
if (SyclIntHeader != nullptr)
11151120
SyclIntHeader->emit(getLangOpts().SYCLIntHeader);
1116-
if (SyclIntFooter != nullptr)
1117-
SyclIntFooter->emit(getLangOpts().SYCLIntFooter);
11181121
MarkDevices();
11191122
}
11201123

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 115 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class Util {
123123
/// specialization id class.
124124
static bool isSyclSpecIdType(QualType Ty);
125125

126+
/// Checks whether given clang type is a full specialization of the SYCL
127+
/// device_global class.
128+
static bool isSyclDeviceGlobalType(QualType Ty);
129+
126130
/// Checks whether given clang type is a full specialization of the SYCL
127131
/// kernel_handler class.
128132
static bool isSyclKernelHandlerType(QualType Ty);
@@ -4692,7 +4696,23 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
46924696
O << "namespace sycl {\n";
46934697
O << "namespace detail {\n";
46944698

4695-
O << "\n";
4699+
// Generate declaration of variable of type __sycl_device_global_registration
4700+
// whose sole purpose is to run its constructor before the application's
4701+
// main() function.
4702+
4703+
if (S.getSyclIntegrationFooter().isDeviceGlobalsEmitted()) {
4704+
O << "namespace {\n";
4705+
4706+
O << "class __sycl_device_global_registration {\n";
4707+
O << "public:\n";
4708+
O << " __sycl_device_global_registration() noexcept;\n";
4709+
O << "};\n";
4710+
O << "__sycl_device_global_registration __sycl_device_global_registrar;\n";
4711+
4712+
O << "} // namespace\n";
4713+
4714+
O << "\n";
4715+
}
46964716

46974717
O << "// names of all kernels defined in the corresponding source\n";
46984718
O << "static constexpr\n";
@@ -4874,9 +4894,9 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
48744894
// template instantiations as a VarDecl.
48754895
if (isa<VarTemplatePartialSpecializationDecl>(VD))
48764896
return;
4877-
// Step 1: ensure that this is of the correct type-spec-constant template
4878-
// specialization).
4879-
if (!Util::isSyclSpecIdType(VD->getType())) {
4897+
// Step 1: ensure that this is of the correct type template specialization.
4898+
if (!Util::isSyclSpecIdType(VD->getType()) &&
4899+
!Util::isSyclDeviceGlobalType(VD->getType())) {
48804900
// Handle the case where this could be a deduced type, such as a deduction
48814901
// guide. We have to do this here since this function, unlike most of the
48824902
// rest of this file, is called during Sema instead of after it. We will
@@ -4892,8 +4912,8 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
48924912
// let an error happen during host compilation.
48934913
if (!VD->hasGlobalStorage() || VD->isLocalVarDeclOrParm())
48944914
return;
4895-
// Step 3: Add to SpecConstants collection.
4896-
SpecConstants.push_back(VD);
4915+
// Step 3: Add to collection.
4916+
GlobalVars.push_back(VD);
48974917
}
48984918

48994919
// Post-compile integration header support.
@@ -4967,29 +4987,28 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
49674987
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
49684988
}
49694989

4970-
static std::string EmitSpecIdShim(raw_ostream &OS, unsigned &ShimCounter,
4971-
const std::string &LastShim,
4972-
const NamespaceDecl *AnonNS) {
4990+
static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter,
4991+
const std::string &LastShim,
4992+
const NamespaceDecl *AnonNS) {
49734993
std::string NewShimName =
4974-
"__sycl_detail::__spec_id_shim_" + std::to_string(ShimCounter) + "()";
4994+
"__sycl_detail::__shim_" + std::to_string(ShimCounter) + "()";
49754995
// Print opening-namespace
49764996
PrintNamespaces(OS, Decl::castToDeclContext(AnonNS));
49774997
OS << "namespace __sycl_detail {\n";
4978-
OS << "static constexpr decltype(" << LastShim << ") &__spec_id_shim_"
4979-
<< ShimCounter << "() {\n";
4998+
OS << "static constexpr decltype(" << LastShim << ") &__shim_" << ShimCounter
4999+
<< "() {\n";
49805000
OS << " return " << LastShim << ";\n";
49815001
OS << "}\n";
4982-
OS << "} // namespace __sycl_detail \n";
5002+
OS << "} // namespace __sycl_detail\n";
49835003
PrintNSClosingBraces(OS, Decl::castToDeclContext(AnonNS));
49845004

49855005
++ShimCounter;
49865006
return NewShimName;
49875007
}
49885008

49895009
// Emit the list of shims required for a DeclContext, calls itself recursively.
4990-
static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
4991-
const DeclContext *DC,
4992-
std::string &NameForLastShim) {
5010+
static void EmitShims(raw_ostream &OS, unsigned &ShimCounter,
5011+
const DeclContext *DC, std::string &NameForLastShim) {
49935012
if (DC->isTranslationUnit()) {
49945013
NameForLastShim = "::" + NameForLastShim;
49955014
return;
@@ -5003,7 +5022,7 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
50035022
} else if (const auto *ND = dyn_cast<NamespaceDecl>(CurDecl)) {
50045023
if (ND->isAnonymousNamespace()) {
50055024
// Print current shim, reset 'name for last shim'.
5006-
NameForLastShim = EmitSpecIdShim(OS, ShimCounter, NameForLastShim, ND);
5025+
NameForLastShim = EmitShim(OS, ShimCounter, NameForLastShim, ND);
50075026
} else {
50085027
NameForLastShim = ND->getNameAsString() + "::" + NameForLastShim;
50095028
}
@@ -5017,22 +5036,22 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
50175036
"Unhandled decl type");
50185037
}
50195038

5020-
EmitSpecIdShims(OS, ShimCounter, CurDecl->getDeclContext(), NameForLastShim);
5039+
EmitShims(OS, ShimCounter, CurDecl->getDeclContext(), NameForLastShim);
50215040
}
50225041

50235042
// Emit the list of shims required for a variable declaration.
50245043
// Returns a string containing the FQN of the 'top most' shim, including its
50255044
// function call parameters.
5026-
static std::string EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
5027-
PrintingPolicy &Policy, const VarDecl *VD) {
5045+
static std::string EmitShims(raw_ostream &OS, unsigned &ShimCounter,
5046+
PrintingPolicy &Policy, const VarDecl *VD) {
50285047
if (!VD->isInAnonymousNamespace())
50295048
return "";
50305049
std::string RelativeName;
50315050
llvm::raw_string_ostream stream(RelativeName);
50325051
VD->getNameForDiagnostic(stream, Policy, false);
50335052
stream.flush();
50345053

5035-
EmitSpecIdShims(OS, ShimCounter, VD->getDeclContext(), RelativeName);
5054+
EmitShims(OS, ShimCounter, VD->getDeclContext(), RelativeName);
50365055
return RelativeName;
50375056
}
50385057

@@ -5042,58 +5061,90 @@ bool SYCLIntegrationFooter::emit(raw_ostream &OS) {
50425061
Policy.SuppressTypedefs = true;
50435062
Policy.SuppressUnwrittenScope = true;
50445063

5045-
llvm::SmallSet<const VarDecl *, 8> VisitedSpecConstants;
5064+
llvm::SmallSet<const VarDecl *, 8> Visited;
50465065
bool EmittedFirstSpecConstant = false;
50475066

50485067
// Used to uniquely name the 'shim's as we generate the names in each
50495068
// anonymous namespace.
50505069
unsigned ShimCounter = 0;
5051-
for (const VarDecl *VD : SpecConstants) {
5070+
5071+
std::string DeviceGlobalsBuf;
5072+
llvm::raw_string_ostream DeviceGlobOS(DeviceGlobalsBuf);
5073+
for (const VarDecl *VD : GlobalVars) {
50525074
VD = VD->getCanonicalDecl();
50535075

5054-
// Skip if this isn't a SpecIdType. This can happen if it was a deduced
5055-
// type.
5056-
if (!Util::isSyclSpecIdType(VD->getType()))
5076+
// Skip if this isn't a SpecIdType or DeviceGlobal. This can happen if it
5077+
// was a deduced type.
5078+
if (!Util::isSyclSpecIdType(VD->getType()) &&
5079+
!Util::isSyclDeviceGlobalType(VD->getType()))
50575080
continue;
50585081

50595082
// Skip if we've already visited this.
5060-
if (llvm::find(VisitedSpecConstants, VD) != VisitedSpecConstants.end())
5083+
if (llvm::find(Visited, VD) != Visited.end())
50615084
continue;
50625085

5063-
// We only want to emit the #includes if we have a spec-constant that needs
5086+
// We only want to emit the #includes if we have a variable that needs
50645087
// them, so emit this one on the first time through the loop.
5065-
if (!EmittedFirstSpecConstant)
5088+
if (!EmittedFirstSpecConstant && !DeviceGlobalsEmitted)
50665089
OS << "#include <CL/sycl/detail/defines_elementary.hpp>\n";
5067-
EmittedFirstSpecConstant = true;
5068-
5069-
VisitedSpecConstants.insert(VD);
5070-
std::string TopShim = EmitSpecIdShims(OS, ShimCounter, Policy, VD);
5071-
OS << "__SYCL_INLINE_NAMESPACE(cl) {\n";
5072-
OS << "namespace sycl {\n";
5073-
OS << "namespace detail {\n";
5074-
OS << "template<>\n";
5075-
OS << "inline const char *get_spec_constant_symbolic_ID_impl<";
5076-
5077-
if (VD->isInAnonymousNamespace()) {
5078-
OS << TopShim;
5090+
5091+
Visited.insert(VD);
5092+
std::string TopShim = EmitShims(OS, ShimCounter, Policy, VD);
5093+
if (Util::isSyclDeviceGlobalType(VD->getType())) {
5094+
DeviceGlobalsEmitted = true;
5095+
DeviceGlobOS << "device_global_map::add(";
5096+
DeviceGlobOS << "(void *)&";
5097+
if (VD->isInAnonymousNamespace()) {
5098+
DeviceGlobOS << TopShim;
5099+
} else {
5100+
DeviceGlobOS << "::";
5101+
VD->getNameForDiagnostic(DeviceGlobOS, Policy, true);
5102+
}
5103+
DeviceGlobOS << ", \"";
5104+
DeviceGlobOS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(),
5105+
VD);
5106+
DeviceGlobOS << "\");\n";
50795107
} else {
5080-
OS << "::";
5081-
VD->getNameForDiagnostic(OS, Policy, true);
5082-
}
5108+
EmittedFirstSpecConstant = true;
5109+
OS << "__SYCL_INLINE_NAMESPACE(cl) {\n";
5110+
OS << "namespace sycl {\n";
5111+
OS << "namespace detail {\n";
5112+
OS << "template<>\n";
5113+
OS << "inline const char *get_spec_constant_symbolic_ID_impl<";
5114+
5115+
if (VD->isInAnonymousNamespace()) {
5116+
OS << TopShim;
5117+
} else {
5118+
OS << "::";
5119+
VD->getNameForDiagnostic(OS, Policy, true);
5120+
}
50835121

5084-
OS << ">() {\n";
5085-
OS << " return \"";
5086-
OS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(), VD);
5087-
OS << "\";\n";
5088-
OS << "}\n";
5089-
OS << "} // namespace detail\n";
5090-
OS << "} // namespace sycl\n";
5091-
OS << "} // __SYCL_INLINE_NAMESPACE(cl)\n";
5122+
OS << ">() {\n";
5123+
OS << " return \"";
5124+
OS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(), VD);
5125+
OS << "\";\n";
5126+
OS << "}\n";
5127+
OS << "} // namespace detail\n";
5128+
OS << "} // namespace sycl\n";
5129+
OS << "} // __SYCL_INLINE_NAMESPACE(cl)\n";
5130+
}
50925131
}
50935132

50945133
if (EmittedFirstSpecConstant)
50955134
OS << "#include <CL/sycl/detail/spec_const_integration.hpp>\n";
50965135

5136+
if (DeviceGlobalsEmitted) {
5137+
OS << "#include <CL/sycl/detail/device_global_map.hpp>\n";
5138+
DeviceGlobOS.flush();
5139+
OS << "namespace sycl::detail {\n";
5140+
OS << "namespace {\n";
5141+
OS << "__sycl_device_global_registration::__sycl_device_global_"
5142+
"registration() noexcept {\n";
5143+
OS << DeviceGlobalsBuf;
5144+
OS << "}\n";
5145+
OS << "} // namespace (unnamed)\n";
5146+
OS << "} // namespace sycl::detail\n";
5147+
}
50975148
return true;
50985149
}
50995150

@@ -5138,6 +5189,18 @@ bool Util::isSyclSpecIdType(QualType Ty) {
51385189
return matchQualifiedTypeName(Ty, Scopes);
51395190
}
51405191

5192+
bool Util::isSyclDeviceGlobalType(QualType Ty) {
5193+
const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl();
5194+
if (!RecTy)
5195+
return false;
5196+
if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RecTy)) {
5197+
ClassTemplateDecl *Template = CTSD->getSpecializedTemplate();
5198+
if (CXXRecordDecl *RD = Template->getTemplatedDecl())
5199+
return RD->hasAttr<SYCLDeviceGlobalAttr>();
5200+
}
5201+
return RecTy->hasAttr<SYCLDeviceGlobalAttr>();
5202+
}
5203+
51415204
bool Util::isSyclKernelHandlerType(QualType Ty) {
51425205
std::array<DeclContextDesc, 3> Scopes = {
51435206
Util::MakeDeclContextDesc(Decl::Kind::Namespace, "cl"),

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ struct no_alias {
129129
template <bool> class instance {};
130130
};
131131
} // namespace property
132+
133+
template <typename T>
134+
class [[__sycl_detail__::device_global]] device_global {
135+
public:
136+
const T &get() const noexcept { return *Data; }
137+
device_global() {}
138+
139+
private:
140+
T *Data;
141+
};
142+
132143
} // namespace oneapi
133144
} // namespace ext
134145

0 commit comments

Comments
 (0)