Skip to content

[SYCL] Emit integration footer and header for device_global variables #5576

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 3, 2022
Merged
9 changes: 9 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,15 @@ def SYCLUsesAspects : InheritableAttr {
let Documentation = [Undocumented];
}

def SYCLDeviceGlobal : InheritableAttr {
let Spellings = [CXX11<"__sycl_detail__", "device_global">];
let Subjects = SubjectList<[CXXRecord], ErrorDiag>;
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
// Only used internally by the SYCL implementation
let Documentation = [Undocumented];
let SimpleHandler = 1;
}

// Marks functions which must not be vectorized via horizontal SIMT widening,
// e.g. because the function is already vectorized. Used to mark SYCL
// explicit SIMD kernels and functions.
Expand Down
4 changes: 3 additions & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,12 +440,14 @@ class SYCLIntegrationFooter {
SYCLIntegrationFooter(Sema &S) : S(S) {}
bool emit(StringRef MainSrc);
void addVarDecl(const VarDecl *VD);
bool isDeviceGlobalsEmitted() { return DeviceGlobalsEmitted; }

private:
bool emit(raw_ostream &O);
Sema &S;
llvm::SmallVector<const VarDecl *> SpecConstants;
llvm::SmallVector<const VarDecl *> GlobalVars;
void emitSpecIDName(raw_ostream &O, const VarDecl *VD);
bool DeviceGlobalsEmitted = false;
};

/// Tracks expected type during expression parsing, for use in code completion.
Expand Down
7 changes: 5 additions & 2 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,11 +1101,14 @@ void Sema::ActOnEndOfTranslationUnitFragment(TUFragmentKind Kind) {
// Set the names of the kernels, now that the names have settled down. This
// needs to happen before we generate the integration headers.
SetSYCLKernelNames();
// Make sure that the footer is emitted before header, since only after the
// footer is emitted is it known that translation unit contains device
// global variables.
if (SyclIntFooter != nullptr)
SyclIntFooter->emit(getLangOpts().SYCLIntFooter);
// Emit SYCL integration header for current translation unit if needed
if (SyclIntHeader != nullptr)
SyclIntHeader->emit(getLangOpts().SYCLIntHeader);
if (SyclIntFooter != nullptr)
SyclIntFooter->emit(getLangOpts().SYCLIntFooter);
MarkDevices();
}

Expand Down
167 changes: 115 additions & 52 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ class Util {
/// specialization id class.
static bool isSyclSpecIdType(QualType Ty);

/// Checks whether given clang type is a full specialization of the SYCL
/// device_global class.
static bool isSyclDeviceGlobalType(QualType Ty);

/// Checks whether given clang type is a full specialization of the SYCL
/// kernel_handler class.
static bool isSyclKernelHandlerType(QualType Ty);
Expand Down Expand Up @@ -4676,7 +4680,23 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
O << "namespace sycl {\n";
O << "namespace detail {\n";

O << "\n";
// Generate declaration of variable of type __sycl_device_global_registration
// whose sole purpose is to run its constructor before the application's
// main() function.

if (S.getSyclIntegrationFooter().isDeviceGlobalsEmitted()) {
O << "namespace {\n";

O << "class __sycl_device_global_registration {\n";
O << "public:\n";
O << " __sycl_device_global_registration() noexcept;\n";
O << "};\n";
O << "__sycl_device_global_registration __sycl_device_global_registrar;\n";

O << "} // namespace\n";

O << "\n";
}

O << "// names of all kernels defined in the corresponding source\n";
O << "static constexpr\n";
Expand Down Expand Up @@ -4858,9 +4878,9 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
// template instantiations as a VarDecl.
if (isa<VarTemplatePartialSpecializationDecl>(VD))
return;
// Step 1: ensure that this is of the correct type-spec-constant template
// specialization).
if (!Util::isSyclSpecIdType(VD->getType())) {
// Step 1: ensure that this is of the correct type template specialization.
if (!Util::isSyclSpecIdType(VD->getType()) &&
!Util::isSyclDeviceGlobalType(VD->getType())) {
// Handle the case where this could be a deduced type, such as a deduction
// guide. We have to do this here since this function, unlike most of the
// rest of this file, is called during Sema instead of after it. We will
Expand All @@ -4876,8 +4896,8 @@ void SYCLIntegrationFooter::addVarDecl(const VarDecl *VD) {
// let an error happen during host compilation.
if (!VD->hasGlobalStorage() || VD->isLocalVarDeclOrParm())
return;
// Step 3: Add to SpecConstants collection.
SpecConstants.push_back(VD);
// Step 3: Add to collection.
GlobalVars.push_back(VD);
}

// Post-compile integration header support.
Expand Down Expand Up @@ -4951,29 +4971,28 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
}

static std::string EmitSpecIdShim(raw_ostream &OS, unsigned &ShimCounter,
const std::string &LastShim,
const NamespaceDecl *AnonNS) {
static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter,
const std::string &LastShim,
const NamespaceDecl *AnonNS) {
std::string NewShimName =
"__sycl_detail::__spec_id_shim_" + std::to_string(ShimCounter) + "()";
"__sycl_detail::__shim_" + std::to_string(ShimCounter) + "()";
// Print opening-namespace
PrintNamespaces(OS, Decl::castToDeclContext(AnonNS));
OS << "namespace __sycl_detail {\n";
OS << "static constexpr decltype(" << LastShim << ") &__spec_id_shim_"
<< ShimCounter << "() {\n";
OS << "static constexpr decltype(" << LastShim << ") &__shim_" << ShimCounter
<< "() {\n";
OS << " return " << LastShim << ";\n";
OS << "}\n";
OS << "} // namespace __sycl_detail \n";
OS << "} // namespace __sycl_detail\n";
PrintNSClosingBraces(OS, Decl::castToDeclContext(AnonNS));

++ShimCounter;
return NewShimName;
}

// Emit the list of shims required for a DeclContext, calls itself recursively.
static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
const DeclContext *DC,
std::string &NameForLastShim) {
static void EmitShims(raw_ostream &OS, unsigned &ShimCounter,
const DeclContext *DC, std::string &NameForLastShim) {
if (DC->isTranslationUnit()) {
NameForLastShim = "::" + NameForLastShim;
return;
Expand All @@ -4987,7 +5006,7 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
} else if (const auto *ND = dyn_cast<NamespaceDecl>(CurDecl)) {
if (ND->isAnonymousNamespace()) {
// Print current shim, reset 'name for last shim'.
NameForLastShim = EmitSpecIdShim(OS, ShimCounter, NameForLastShim, ND);
NameForLastShim = EmitShim(OS, ShimCounter, NameForLastShim, ND);
} else {
NameForLastShim = ND->getNameAsString() + "::" + NameForLastShim;
}
Expand All @@ -5001,22 +5020,22 @@ static void EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
"Unhandled decl type");
}

EmitSpecIdShims(OS, ShimCounter, CurDecl->getDeclContext(), NameForLastShim);
EmitShims(OS, ShimCounter, CurDecl->getDeclContext(), NameForLastShim);
}

// Emit the list of shims required for a variable declaration.
// Returns a string containing the FQN of the 'top most' shim, including its
// function call parameters.
static std::string EmitSpecIdShims(raw_ostream &OS, unsigned &ShimCounter,
PrintingPolicy &Policy, const VarDecl *VD) {
static std::string EmitShims(raw_ostream &OS, unsigned &ShimCounter,
PrintingPolicy &Policy, const VarDecl *VD) {
if (!VD->isInAnonymousNamespace())
return "";
std::string RelativeName;
llvm::raw_string_ostream stream(RelativeName);
VD->getNameForDiagnostic(stream, Policy, false);
stream.flush();

EmitSpecIdShims(OS, ShimCounter, VD->getDeclContext(), RelativeName);
EmitShims(OS, ShimCounter, VD->getDeclContext(), RelativeName);
return RelativeName;
}

Expand All @@ -5026,58 +5045,90 @@ bool SYCLIntegrationFooter::emit(raw_ostream &OS) {
Policy.SuppressTypedefs = true;
Policy.SuppressUnwrittenScope = true;

llvm::SmallSet<const VarDecl *, 8> VisitedSpecConstants;
llvm::SmallSet<const VarDecl *, 8> Visited;
bool EmittedFirstSpecConstant = false;

// Used to uniquely name the 'shim's as we generate the names in each
// anonymous namespace.
unsigned ShimCounter = 0;
for (const VarDecl *VD : SpecConstants) {

std::string DeviceGlobalsBuf;
llvm::raw_string_ostream DeviceGlobOS(DeviceGlobalsBuf);
for (const VarDecl *VD : GlobalVars) {
VD = VD->getCanonicalDecl();

// Skip if this isn't a SpecIdType. This can happen if it was a deduced
// type.
if (!Util::isSyclSpecIdType(VD->getType()))
// Skip if this isn't a SpecIdType or DeviceGlobal. This can happen if it
// was a deduced type.
if (!Util::isSyclSpecIdType(VD->getType()) &&
!Util::isSyclDeviceGlobalType(VD->getType()))
continue;

// Skip if we've already visited this.
if (llvm::find(VisitedSpecConstants, VD) != VisitedSpecConstants.end())
if (llvm::find(Visited, VD) != Visited.end())
continue;

// We only want to emit the #includes if we have a spec-constant that needs
// We only want to emit the #includes if we have a variable that needs
// them, so emit this one on the first time through the loop.
if (!EmittedFirstSpecConstant)
if (!EmittedFirstSpecConstant && !DeviceGlobalsEmitted)
OS << "#include <CL/sycl/detail/defines_elementary.hpp>\n";
EmittedFirstSpecConstant = true;

VisitedSpecConstants.insert(VD);
std::string TopShim = EmitSpecIdShims(OS, ShimCounter, Policy, VD);
OS << "__SYCL_INLINE_NAMESPACE(cl) {\n";
OS << "namespace sycl {\n";
OS << "namespace detail {\n";
OS << "template<>\n";
OS << "inline const char *get_spec_constant_symbolic_ID_impl<";

if (VD->isInAnonymousNamespace()) {
OS << TopShim;

Visited.insert(VD);
std::string TopShim = EmitShims(OS, ShimCounter, Policy, VD);
if (Util::isSyclDeviceGlobalType(VD->getType())) {
DeviceGlobalsEmitted = true;
DeviceGlobOS << "device_global_map::add(";
DeviceGlobOS << "(void *)&";
if (VD->isInAnonymousNamespace()) {
DeviceGlobOS << TopShim;
} else {
DeviceGlobOS << "::";
VD->getNameForDiagnostic(DeviceGlobOS, Policy, true);
}
DeviceGlobOS << ", \"";
DeviceGlobOS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(),
VD);
DeviceGlobOS << "\");\n";
} else {
OS << "::";
VD->getNameForDiagnostic(OS, Policy, true);
}
EmittedFirstSpecConstant = true;
OS << "__SYCL_INLINE_NAMESPACE(cl) {\n";
OS << "namespace sycl {\n";
OS << "namespace detail {\n";
OS << "template<>\n";
OS << "inline const char *get_spec_constant_symbolic_ID_impl<";

if (VD->isInAnonymousNamespace()) {
OS << TopShim;
} else {
OS << "::";
VD->getNameForDiagnostic(OS, Policy, true);
}

OS << ">() {\n";
OS << " return \"";
OS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(), VD);
OS << "\";\n";
OS << "}\n";
OS << "} // namespace detail\n";
OS << "} // namespace sycl\n";
OS << "} // __SYCL_INLINE_NAMESPACE(cl)\n";
OS << ">() {\n";
OS << " return \"";
OS << SYCLUniqueStableIdExpr::ComputeName(S.getASTContext(), VD);
OS << "\";\n";
OS << "}\n";
OS << "} // namespace detail\n";
OS << "} // namespace sycl\n";
OS << "} // __SYCL_INLINE_NAMESPACE(cl)\n";
}
}

if (EmittedFirstSpecConstant)
OS << "#include <CL/sycl/detail/spec_const_integration.hpp>\n";

if (DeviceGlobalsEmitted) {
OS << "#include <CL/sycl/detail/device_global_map.hpp>\n";
DeviceGlobOS.flush();
OS << "namespace sycl::detail {\n";
OS << "namespace {\n";
OS << "__sycl_device_global_registration::__sycl_device_global_"
"registration() noexcept {\n";
OS << DeviceGlobalsBuf;
OS << "}\n";
OS << "} // namespace (unnamed)\n";
OS << "} // namespace sycl::detail\n";
}
return true;
}

Expand Down Expand Up @@ -5122,6 +5173,18 @@ bool Util::isSyclSpecIdType(QualType Ty) {
return matchQualifiedTypeName(Ty, Scopes);
}

bool Util::isSyclDeviceGlobalType(QualType Ty) {
const CXXRecordDecl *RecTy = Ty->getAsCXXRecordDecl();
if (!RecTy)
return false;
if (auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(RecTy)) {
ClassTemplateDecl *Template = CTSD->getSpecializedTemplate();
if (CXXRecordDecl *RD = Template->getTemplatedDecl())
return RD->hasAttr<SYCLDeviceGlobalAttr>();
}
return RecTy->hasAttr<SYCLDeviceGlobalAttr>();
}

bool Util::isSyclKernelHandlerType(QualType Ty) {
std::array<DeclContextDesc, 3> Scopes = {
Util::MakeDeclContextDesc(Decl::Kind::Namespace, "cl"),
Expand Down
11 changes: 11 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@ struct no_alias {
template <bool> class instance {};
};
} // namespace property

template <typename T>
class [[__sycl_detail__::device_global]] device_global {
public:
const T &get() const noexcept { return *Data; }
device_global() {}

private:
T *Data;
};

} // namespace oneapi
} // namespace ext

Expand Down
Loading