@@ -378,11 +378,12 @@ class kernel_bundle_impl {
378
378
379
379
// oneapi_ext_kernel_compiler
380
380
// program manager integration, only for sycl_jit language
381
- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
382
- const std::vector<kernel_id> &KernelIDs,
383
- std::vector<std::string> KNames,
384
- sycl_device_binaries Binaries, std::string Pfx,
385
- syclex::source_language Lang)
381
+ kernel_bundle_impl (
382
+ context Ctx, std::vector<device> Devs,
383
+ const std::vector<kernel_id> &KernelIDs,
384
+ std::unordered_map<std::string, std::string> &&MangledKernelNames,
385
+ sycl_device_binaries Binaries, std::string &&Prefix,
386
+ syclex::source_language Lang)
386
387
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
387
388
bundle_state::executable) {
388
389
assert (Lang == syclex::source_language::sycl_jit);
@@ -392,9 +393,9 @@ class kernel_bundle_impl {
392
393
// loaded via the program manager have `kernel_id`s, they can't be looked up
393
394
// from the (unprefixed) kernel name.
394
395
MIsInterop = true ;
395
- MKernelNames = std::move (KNames );
396
+ MMangledKernelNames = std::move (MangledKernelNames );
396
397
MDeviceBinaries = Binaries;
397
- MPrefix = std::move (Pfx );
398
+ MPrefix = std::move (Prefix );
398
399
MLanguage = Lang;
399
400
}
400
401
@@ -501,15 +502,35 @@ class kernel_bundle_impl {
501
502
// TODO: Support persistent caching.
502
503
503
504
const std::string &SourceStr = std::get<std::string>(MSource);
505
+ std::ostringstream SourceExt;
506
+ if (!RegisteredKernelNames.empty ()) {
507
+ SourceExt << SourceStr << ' \n ' ;
508
+
509
+ auto EmitEntry =
510
+ [&SourceExt](const std::string &Name) -> std::ostringstream & {
511
+ SourceExt << " {\" " << Name << " \" , " << Name << " }" ;
512
+ return SourceExt;
513
+ };
514
+
515
+ SourceExt << " [[__sycl_detail__::__registered_kernels__(\n " ;
516
+ for (auto It = RegisteredKernelNames.begin (),
517
+ SecondToLast = RegisteredKernelNames.end () - 1 ;
518
+ It != SecondToLast; ++It) {
519
+ EmitEntry (*It) << " ,\n " ;
520
+ }
521
+ EmitEntry (RegisteredKernelNames.back ()) << " \n " ;
522
+ SourceExt << " )]];\n " ;
523
+ }
524
+
504
525
auto [Binaries, CompilationID] = syclex::detail::SYCL_JIT_to_SPIRV (
505
- SourceStr, MIncludePairs, BuildOptions, LogPtr ,
506
- RegisteredKernelNames );
526
+ RegisteredKernelNames. empty () ? SourceStr : SourceExt. str () ,
527
+ MIncludePairs, BuildOptions, LogPtr );
507
528
508
529
auto &PM = detail::ProgramManager::getInstance ();
509
530
PM.addImages (Binaries);
510
531
511
532
std::vector<kernel_id> KernelIDs;
512
- std::vector <std::string> KernelNames ;
533
+ std::unordered_map <std::string, std::string> MangledKernelNames ;
513
534
// `jit_compiler::compileSYCL(..)` uses `CompilationID + '$'` as prefix
514
535
// for offload entry names.
515
536
std::string Prefix = CompilationID + ' $' ;
@@ -518,13 +539,38 @@ class kernel_bundle_impl {
518
539
if (KernelName.find (Prefix) == 0 ) {
519
540
KernelIDs.push_back (KernelID);
520
541
KernelName.remove_prefix (Prefix.length ());
521
- KernelNames.emplace_back (KernelName);
542
+ static constexpr std::string_view SYCLKernelMarker{" __sycl_kernel_" };
543
+ if (KernelName.find (SYCLKernelMarker) == 0 ) {
544
+ // extern "C" declaration, register kernel without the marker.
545
+ std::string_view KernelNameWithoutMarker{KernelName};
546
+ KernelNameWithoutMarker.remove_prefix (SYCLKernelMarker.length ());
547
+ MangledKernelNames.emplace (KernelNameWithoutMarker, KernelName);
548
+ } else {
549
+ // The marker is baked into the mangling, and we cannot easily
550
+ // adjust it. Register an identity mapping as an escape hatch.
551
+ // Users shall use `registered_kernel_names` instead, as there's
552
+ // practically no way to guess the mangled name.
553
+ MangledKernelNames.emplace (KernelName, KernelName);
554
+ }
522
555
}
523
556
}
524
557
525
- return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
526
- KernelNames, Binaries, Prefix,
527
- MLanguage);
558
+ // Apply frontend information.
559
+ for (const auto *RawImg : PM.getRawDeviceImages (KernelIDs)) {
560
+ for (const sycl_device_binary_property &RKProp :
561
+ RawImg->getRegisteredKernels ()) {
562
+
563
+ auto BA = DeviceBinaryProperty (RKProp).asByteArray ();
564
+ auto MangledNameLen = BA.consume <uint64_t >() / 8 /* bits in a byte*/ ;
565
+ std::string_view MangledName{
566
+ reinterpret_cast <const char *>(BA.begin ()), MangledNameLen};
567
+ MangledKernelNames.emplace (RKProp->Name , MangledName);
568
+ }
569
+ }
570
+
571
+ return std::make_shared<kernel_bundle_impl>(
572
+ MContext, MDevices, KernelIDs, std::move (MangledKernelNames),
573
+ Binaries, std::move (Prefix), MLanguage);
528
574
}
529
575
530
576
ur_program_handle_t UrProgram = nullptr ;
@@ -642,6 +688,9 @@ class kernel_bundle_impl {
642
688
}
643
689
644
690
bool ext_oneapi_has_kernel (const std::string &Name) {
691
+ if (MLanguage == syclex::source_language::sycl_jit) {
692
+ return MMangledKernelNames.count (Name);
693
+ }
645
694
auto it = std::find (MKernelNames.begin (), MKernelNames.end (),
646
695
adjust_kernel_name (Name, MLanguage));
647
696
return it != MKernelNames.end ();
@@ -650,21 +699,25 @@ class kernel_bundle_impl {
650
699
kernel
651
700
ext_oneapi_get_kernel (const std::string &Name,
652
701
const std::shared_ptr<kernel_bundle_impl> &Self) {
653
- if (MKernelNames.empty ())
654
- throw sycl::exception (make_error_code (errc::invalid),
655
- " 'ext_oneapi_get_kernel' is only available in "
656
- " kernel_bundles successfully built from "
657
- " kernel_bundle<bundle_state:ext_oneapi_source>." );
702
+ if (MLanguage == syclex::source_language::sycl_jit) {
703
+ if (MMangledKernelNames.empty ()) {
704
+ throw sycl::exception (
705
+ make_error_code (errc::invalid),
706
+ " 'ext_oneapi_get_kernel' is only available in kernel_bundles "
707
+ " successfully built from "
708
+ " kernel_bundle<bundle_state::ext_oneapi_source>." );
709
+ }
658
710
659
- std::string AdjustedName = adjust_kernel_name (Name, MLanguage);
660
- if (!ext_oneapi_has_kernel (Name))
661
- throw sycl::exception (make_error_code (errc::invalid),
662
- " kernel '" + AdjustedName +
663
- " ' not found in kernel_bundle" );
711
+ auto It = MMangledKernelNames.find (Name);
712
+ if (It == MMangledKernelNames.end ()) {
713
+ throw sycl::exception (make_error_code (errc::invalid),
714
+ " kernel '" + Name +
715
+ " ' not found in kernel_bundle" );
716
+ }
664
717
665
- if (MLanguage == syclex::source_language::sycl_jit) {
718
+ const std::string &MangledName = It-> second ;
666
719
auto &PM = ProgramManager::getInstance ();
667
- auto KID = PM.getSYCLKernelID (MPrefix + AdjustedName );
720
+ auto KID = PM.getSYCLKernelID (MPrefix + MangledName );
668
721
669
722
for (const auto &DevImgWithDeps : MDeviceImages) {
670
723
const auto &DevImg = DevImgWithDeps.getMain ();
@@ -674,7 +727,7 @@ class kernel_bundle_impl {
674
727
const auto &DevImgImpl = getSyclObjImpl (DevImg);
675
728
auto UrProgram = DevImgImpl->get_ur_program_ref ();
676
729
auto [UrKernel, CacheMutex, ArgMask] =
677
- PM.getOrCreateKernel (MContext, AdjustedName ,
730
+ PM.getOrCreateKernel (MContext, MangledName ,
678
731
/* PropList=*/ {}, UrProgram);
679
732
auto KernelImpl = std::make_shared<kernel_impl>(
680
733
UrKernel, getSyclObjImpl (MContext), DevImgImpl, Self, ArgMask,
@@ -685,6 +738,18 @@ class kernel_bundle_impl {
685
738
assert (false && " Malformed RTC kernel bundle" );
686
739
}
687
740
741
+ if (MKernelNames.empty ())
742
+ throw sycl::exception (make_error_code (errc::invalid),
743
+ " 'ext_oneapi_get_kernel' is only available in "
744
+ " kernel_bundles successfully built from "
745
+ " kernel_bundle<bundle_state:ext_oneapi_source>." );
746
+
747
+ std::string AdjustedName = adjust_kernel_name (Name, MLanguage);
748
+ if (!ext_oneapi_has_kernel (Name))
749
+ throw sycl::exception (make_error_code (errc::invalid),
750
+ " kernel '" + AdjustedName +
751
+ " ' not found in kernel_bundle" );
752
+
688
753
assert (MDeviceImages.size () > 0 );
689
754
const std::shared_ptr<detail::device_image_impl> &DeviceImageImpl =
690
755
detail::getSyclObjImpl (MDeviceImages[0 ].getMain ());
@@ -877,12 +942,11 @@ class kernel_bundle_impl {
877
942
}
878
943
879
944
// Reports whether the specialization constant named by SpecName counts as set
// for this bundle: either some device image in the range [begin(), end())
// answers true to is_specialization_constant_set(SpecName), or
// MSpecConstValues contains an entry keyed by that name.
// The removed (-) and added (+) lines below spell the same std::any_of call;
// only the line wrapping differs.
// NOTE(review): the function is noexcept, yet the final lookup builds a
// temporary std::string from SpecName, which may allocate — confirm that
// terminating on allocation failure is acceptable here.
bool is_specialization_constant_set (const char *SpecName) const noexcept {
880
- bool SetInDevImg =
881
- std::any_of (begin (), end (),
882
- [SpecName](const device_image_plain &DeviceImage) {
883
- return getSyclObjImpl (DeviceImage)
884
- ->is_specialization_constant_set (SpecName);
885
- });
945
+ bool SetInDevImg = std::any_of (
946
+ begin (), end (), [SpecName](const device_image_plain &DeviceImage) {
947
+ return getSyclObjImpl (DeviceImage)
948
+ ->is_specialization_constant_set (SpecName);
949
+ });
886
950
return SetInDevImg || MSpecConstValues.count (std::string{SpecName}) != 0 ;
887
951
}
888
952
@@ -973,6 +1037,7 @@ class kernel_bundle_impl {
973
1037
const std::variant<std::string, std::vector<std::byte>> MSource;
974
1038
// only kernel_bundles created from source have KernelNames member.
975
1039
std::vector<std::string> MKernelNames;
1040
+ std::unordered_map<std::string, std::string> MMangledKernelNames;
976
1041
sycl_device_binaries MDeviceBinaries = nullptr ;
977
1042
std::string MPrefix;
978
1043
include_pairs_t MIncludePairs;
0 commit comments