Skip to content

Commit 58dfefb

Browse files
committed
[SYCL][RTC] Query kernels by source code name
Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com>
1 parent 47630fe commit 58dfefb

File tree

9 files changed

+123
-68
lines changed

9 files changed

+123
-68
lines changed

sycl/source/detail/compiler.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,8 @@
6868
#define __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS "SYCL/virtual functions"
6969
/// PropertySetRegistry::SYCL_IMPLICIT_LOCAL_ARG defined in PropertySetIO.h
7070
#define __SYCL_PROPERTY_SET_SYCL_IMPLICIT_LOCAL_ARG "SYCL/implicit local arg"
71+
/// PropertySetRegistry::SYCL_REGISTERED_KERNELS defined in PropertySetIO.h
72+
#define __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS "SYCL/registered kernels"
7173

7274
/// Program metadata tags recognized by the PI backends. For kernels the tag
7375
/// must appear after the kernel name.

sycl/source/detail/device_binary_image.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,7 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
195195
DeviceRequirements.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
196196
HostPipes.init(Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES);
197197
VirtualFunctions.init(Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);
198+
RegisteredKernels.init(Bin, __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS);
198199

199200
ImageId = ImageCounter++;
200201
}

sycl/source/detail/device_binary_image.hpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -232,6 +232,9 @@ class RTDeviceBinaryImage {
232232
const PropertyRange &getHostPipes() const { return HostPipes; }
233233
const PropertyRange &getVirtualFunctions() const { return VirtualFunctions; }
234234
const PropertyRange &getImplicitLocalArg() const { return ImplicitLocalArg; }
235+
const PropertyRange &getRegisteredKernels() const {
236+
return RegisteredKernels;
237+
}
235238

236239
std::uintptr_t getImageID() const {
237240
assert(Bin && "Image ID is not available without a binary image.");
@@ -258,6 +261,7 @@ class RTDeviceBinaryImage {
258261
RTDeviceBinaryImage::PropertyRange HostPipes;
259262
RTDeviceBinaryImage::PropertyRange VirtualFunctions;
260263
RTDeviceBinaryImage::PropertyRange ImplicitLocalArg;
264+
RTDeviceBinaryImage::PropertyRange RegisteredKernels;
261265

262266
std::vector<ur_program_metadata_t> ProgramMetadataUR;
263267

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1249,24 +1249,11 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
12491249
sycl_device_binaries jit_compiler::compileSYCL(
12501250
const std::string &CompilationID, const std::string &SYCLSource,
12511251
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
1252-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
1253-
const std::vector<std::string> &RegisteredKernelNames) {
1254-
1255-
// RegisteredKernelNames may contain template specializations, so we just put
1256-
// them in main() which ensures they are instantiated.
1257-
std::ostringstream ss;
1258-
ss << SYCLSource << '\n';
1259-
ss << "int main() {\n";
1260-
for (const std::string &KernelName : RegisteredKernelNames) {
1261-
ss << " (void)" << KernelName << ";\n";
1262-
}
1263-
ss << " return 0;\n}\n" << std::endl;
1264-
1265-
std::string FinalSource = ss.str();
1252+
const std::vector<std::string> &UserArgs, std::string *LogPtr) {
12661253

12671254
std::string SYCLFileName = CompilationID + ".cpp";
12681255
::jit_compiler::InMemoryFile SourceFile{SYCLFileName.c_str(),
1269-
FinalSource.c_str()};
1256+
SYCLSource.c_str()};
12701257

12711258
std::vector<::jit_compiler::InMemoryFile> IncludeFilesView;
12721259
IncludeFilesView.reserve(IncludePairs.size());

sycl/source/detail/jit_compiler.hpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,7 @@ class jit_compiler {
5252
sycl_device_binaries compileSYCL(
5353
const std::string &CompilationID, const std::string &SYCLSource,
5454
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
55-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
56-
const std::vector<std::string> &RegisteredKernelNames);
55+
const std::vector<std::string> &UserArgs, std::string *LogPtr);
5756

5857
void destroyDeviceBinaries(sycl_device_binaries Binaries);
5958

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 98 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -378,11 +378,12 @@ class kernel_bundle_impl {
378378

379379
// oneapi_ext_kernel_compiler
380380
// program manager integration, only for sycl_jit language
381-
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
382-
const std::vector<kernel_id> &KernelIDs,
383-
std::vector<std::string> KNames,
384-
sycl_device_binaries Binaries, std::string Pfx,
385-
syclex::source_language Lang)
381+
kernel_bundle_impl(
382+
context Ctx, std::vector<device> Devs,
383+
const std::vector<kernel_id> &KernelIDs,
384+
std::unordered_map<std::string, std::string> &&MangledKernelNames,
385+
sycl_device_binaries Binaries, std::string &&Prefix,
386+
syclex::source_language Lang)
386387
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
387388
bundle_state::executable) {
388389
assert(Lang == syclex::source_language::sycl_jit);
@@ -392,9 +393,9 @@ class kernel_bundle_impl {
392393
// loaded via the program manager have `kernel_id`s, they can't be looked up
393394
// from the (unprefixed) kernel name.
394395
MIsInterop = true;
395-
MKernelNames = std::move(KNames);
396+
MMangledKernelNames = std::move(MangledKernelNames);
396397
MDeviceBinaries = Binaries;
397-
MPrefix = std::move(Pfx);
398+
MPrefix = std::move(Prefix);
398399
MLanguage = Lang;
399400
}
400401

@@ -501,15 +502,35 @@ class kernel_bundle_impl {
501502
// TODO: Support persistent caching.
502503

503504
const std::string &SourceStr = std::get<std::string>(MSource);
505+
std::ostringstream SourceExt;
506+
if (!RegisteredKernelNames.empty()) {
507+
SourceExt << SourceStr << '\n';
508+
509+
auto EmitEntry =
510+
[&SourceExt](const std::string &Name) -> std::ostringstream & {
511+
SourceExt << " {\"" << Name << "\", " << Name << "}";
512+
return SourceExt;
513+
};
514+
515+
SourceExt << "[[__sycl_detail__::__registered_kernels__(\n";
516+
for (auto It = RegisteredKernelNames.begin(),
517+
SecondToLast = RegisteredKernelNames.end() - 1;
518+
It != SecondToLast; ++It) {
519+
EmitEntry(*It) << ",\n";
520+
}
521+
EmitEntry(RegisteredKernelNames.back()) << "\n";
522+
SourceExt << ")]];\n";
523+
}
524+
504525
auto [Binaries, CompilationID] = syclex::detail::SYCL_JIT_to_SPIRV(
505-
SourceStr, MIncludePairs, BuildOptions, LogPtr,
506-
RegisteredKernelNames);
526+
RegisteredKernelNames.empty() ? SourceStr : SourceExt.str(),
527+
MIncludePairs, BuildOptions, LogPtr);
507528

508529
auto &PM = detail::ProgramManager::getInstance();
509530
PM.addImages(Binaries);
510531

511532
std::vector<kernel_id> KernelIDs;
512-
std::vector<std::string> KernelNames;
533+
std::unordered_map<std::string, std::string> MangledKernelNames;
513534
// `jit_compiler::compileSYCL(..)` uses `CompilationID + '$'` as prefix
514535
// for offload entry names.
515536
std::string Prefix = CompilationID + '$';
@@ -518,13 +539,38 @@ class kernel_bundle_impl {
518539
if (KernelName.find(Prefix) == 0) {
519540
KernelIDs.push_back(KernelID);
520541
KernelName.remove_prefix(Prefix.length());
521-
KernelNames.emplace_back(KernelName);
542+
static constexpr std::string_view SYCLKernelMarker{"__sycl_kernel_"};
543+
if (KernelName.find(SYCLKernelMarker) == 0) {
544+
// extern "C" declaration, register kernel without the marker.
545+
std::string_view KernelNameWithoutMarker{KernelName};
546+
KernelNameWithoutMarker.remove_prefix(SYCLKernelMarker.length());
547+
MangledKernelNames.emplace(KernelNameWithoutMarker, KernelName);
548+
} else {
549+
// The marker is baked into the mangling, and we cannot easily
550+
// adjust it. Register an identity mapping as an escape hatch.
551+
// Users shall use `registered_kernel_names` instead, as there's
552+
// practically no way to guess the mangled name.
553+
MangledKernelNames.emplace(KernelName, KernelName);
554+
}
522555
}
523556
}
524557

525-
return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
526-
KernelNames, Binaries, Prefix,
527-
MLanguage);
558+
// Apply frontend information.
559+
for (const auto *RawImg : PM.getRawDeviceImages(KernelIDs)) {
560+
for (const sycl_device_binary_property &RKProp :
561+
RawImg->getRegisteredKernels()) {
562+
563+
auto BA = DeviceBinaryProperty(RKProp).asByteArray();
564+
auto MangledNameLen = BA.consume<uint64_t>() / 8 /*bits in a byte*/;
565+
std::string_view MangledName{
566+
reinterpret_cast<const char *>(BA.begin()), MangledNameLen};
567+
MangledKernelNames.emplace(RKProp->Name, MangledName);
568+
}
569+
}
570+
571+
return std::make_shared<kernel_bundle_impl>(
572+
MContext, MDevices, KernelIDs, std::move(MangledKernelNames),
573+
Binaries, std::move(Prefix), MLanguage);
528574
}
529575

530576
ur_program_handle_t UrProgram = nullptr;
@@ -642,6 +688,9 @@ class kernel_bundle_impl {
642688
}
643689

644690
bool ext_oneapi_has_kernel(const std::string &Name) {
691+
if (MLanguage == syclex::source_language::sycl_jit) {
692+
return MMangledKernelNames.count(Name);
693+
}
645694
auto it = std::find(MKernelNames.begin(), MKernelNames.end(),
646695
adjust_kernel_name(Name, MLanguage));
647696
return it != MKernelNames.end();
@@ -650,21 +699,25 @@ class kernel_bundle_impl {
650699
kernel
651700
ext_oneapi_get_kernel(const std::string &Name,
652701
const std::shared_ptr<kernel_bundle_impl> &Self) {
653-
if (MKernelNames.empty())
654-
throw sycl::exception(make_error_code(errc::invalid),
655-
"'ext_oneapi_get_kernel' is only available in "
656-
"kernel_bundles successfully built from "
657-
"kernel_bundle<bundle_state:ext_oneapi_source>.");
702+
if (MLanguage == syclex::source_language::sycl_jit) {
703+
if (MMangledKernelNames.empty()) {
704+
throw sycl::exception(
705+
make_error_code(errc::invalid),
706+
"'ext_oneapi_get_kernel' is only available in kernel_bundles "
707+
"successfully built from "
708+
"kernel_bundle<bundle_state::ext_oneapi_source>.");
709+
}
658710

659-
std::string AdjustedName = adjust_kernel_name(Name, MLanguage);
660-
if (!ext_oneapi_has_kernel(Name))
661-
throw sycl::exception(make_error_code(errc::invalid),
662-
"kernel '" + AdjustedName +
663-
"' not found in kernel_bundle");
711+
auto It = MMangledKernelNames.find(Name);
712+
if (It == MMangledKernelNames.end()) {
713+
throw sycl::exception(make_error_code(errc::invalid),
714+
"kernel '" + Name +
715+
"' not found in kernel_bundle");
716+
}
664717

665-
if (MLanguage == syclex::source_language::sycl_jit) {
718+
const std::string &MangledName = It->second;
666719
auto &PM = ProgramManager::getInstance();
667-
auto KID = PM.getSYCLKernelID(MPrefix + AdjustedName);
720+
auto KID = PM.getSYCLKernelID(MPrefix + MangledName);
668721

669722
for (const auto &DevImgWithDeps : MDeviceImages) {
670723
const auto &DevImg = DevImgWithDeps.getMain();
@@ -674,7 +727,7 @@ class kernel_bundle_impl {
674727
const auto &DevImgImpl = getSyclObjImpl(DevImg);
675728
auto UrProgram = DevImgImpl->get_ur_program_ref();
676729
auto [UrKernel, CacheMutex, ArgMask] =
677-
PM.getOrCreateKernel(MContext, AdjustedName,
730+
PM.getOrCreateKernel(MContext, MangledName,
678731
/*PropList=*/{}, UrProgram);
679732
auto KernelImpl = std::make_shared<kernel_impl>(
680733
UrKernel, getSyclObjImpl(MContext), DevImgImpl, Self, ArgMask,
@@ -685,6 +738,18 @@ class kernel_bundle_impl {
685738
assert(false && "Malformed RTC kernel bundle");
686739
}
687740

741+
if (MKernelNames.empty())
742+
throw sycl::exception(make_error_code(errc::invalid),
743+
"'ext_oneapi_get_kernel' is only available in "
744+
"kernel_bundles successfully built from "
745+
"kernel_bundle<bundle_state:ext_oneapi_source>.");
746+
747+
std::string AdjustedName = adjust_kernel_name(Name, MLanguage);
748+
if (!ext_oneapi_has_kernel(Name))
749+
throw sycl::exception(make_error_code(errc::invalid),
750+
"kernel '" + AdjustedName +
751+
"' not found in kernel_bundle");
752+
688753
assert(MDeviceImages.size() > 0);
689754
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
690755
detail::getSyclObjImpl(MDeviceImages[0].getMain());
@@ -877,12 +942,11 @@ class kernel_bundle_impl {
877942
}
878943

879944
bool is_specialization_constant_set(const char *SpecName) const noexcept {
880-
bool SetInDevImg =
881-
std::any_of(begin(), end(),
882-
[SpecName](const device_image_plain &DeviceImage) {
883-
return getSyclObjImpl(DeviceImage)
884-
->is_specialization_constant_set(SpecName);
885-
});
945+
bool SetInDevImg = std::any_of(
946+
begin(), end(), [SpecName](const device_image_plain &DeviceImage) {
947+
return getSyclObjImpl(DeviceImage)
948+
->is_specialization_constant_set(SpecName);
949+
});
886950
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
887951
}
888952

@@ -973,6 +1037,7 @@ class kernel_bundle_impl {
9731037
const std::variant<std::string, std::vector<std::byte>> MSource;
9741038
// only kernel_bundles created from source have KernelNames member.
9751039
std::vector<std::string> MKernelNames;
1040+
std::unordered_map<std::string, std::string> MMangledKernelNames;
9761041
sycl_device_binaries MDeviceBinaries = nullptr;
9771042
std::string MPrefix;
9781043
include_pairs_t MIncludePairs;

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -303,19 +303,17 @@ bool SYCL_JIT_Compilation_Available() {
303303
#endif
304304
}
305305

306-
std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
307-
[[maybe_unused]] const std::string &SYCLSource,
308-
[[maybe_unused]] include_pairs_t IncludePairs,
309-
[[maybe_unused]] const std::vector<std::string> &UserArgs,
310-
[[maybe_unused]] std::string *LogPtr,
311-
[[maybe_unused]] const std::vector<std::string> &RegisteredKernelNames) {
306+
std::pair<sycl_device_binaries, std::string>
307+
SYCL_JIT_to_SPIRV([[maybe_unused]] const std::string &SYCLSource,
308+
[[maybe_unused]] include_pairs_t IncludePairs,
309+
[[maybe_unused]] const std::vector<std::string> &UserArgs,
310+
[[maybe_unused]] std::string *LogPtr) {
312311
#if SYCL_EXT_JIT_ENABLE
313312
static std::atomic_uintptr_t CompilationCounter;
314313
std::string CompilationID = "rtc_" + std::to_string(CompilationCounter++);
315314
sycl_device_binaries Binaries =
316315
sycl::detail::jit_compiler::get_instance().compileSYCL(
317-
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr,
318-
RegisteredKernelNames);
316+
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr);
319317
return std::make_pair(Binaries, std::move(CompilationID));
320318
#else
321319
throw sycl::exception(sycl::errc::build,

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ std::string userArgsAsString(const std::vector<std::string> &UserArguments);
3737

3838
std::pair<sycl_device_binaries, std::string>
3939
SYCL_JIT_to_SPIRV(const std::string &Source, include_pairs_t IncludePairs,
40-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
41-
const std::vector<std::string> &RegisteredKernelNames);
40+
const std::vector<std::string> &UserArgs,
41+
std::string *LogPtr);
4242

4343
void SYCL_JIT_destroy(sycl_device_binaries Binaries);
4444

sycl/test-e2e/KernelCompiler/kernel_compiler_sycl_jit.cpp

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -207,14 +207,13 @@ int test_build_and_run() {
207207
// directly.
208208
sycl::kernel k = kbExe2.ext_oneapi_get_kernel("ff_cp");
209209

210-
// The templated function name will have been mangled. Mapping from original
211-
// name to mangled is not yet supported. So we cannot yet do this:
212-
// sycl::kernel k2 = kbExe2.ext_oneapi_get_kernel("ff_templated<int>");
213-
214-
// Instead, we can TEMPORARILY use the mangled name. Once demangling is
215-
// supported this might no longer work.
216-
sycl::kernel k2 =
217-
kbExe2.ext_oneapi_get_kernel("_Z26__sycl_kernel_ff_templatedIiEvPT_S1_");
210+
// The templated function name will have been mangled.
211+
sycl::kernel k2 = kbExe2.ext_oneapi_get_kernel("ff_templated<int>");
212+
213+
// We can also use the mangled name. This escape hatch might be removed in the
214+
// future.
215+
assert(
216+
kbExe2.ext_oneapi_has_kernel("_Z26__sycl_kernel_ff_templatedIiEvPT_S1_"));
218217

219218
// Test the kernels.
220219
test_1(q, k, 37 + 5); // ff_cp seeds 37. AddEm will add 5 more.

0 commit comments

Comments
 (0)