Skip to content

[SYCL][RTC] Query kernels by source code name #17032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sycl/include/sycl/kernel_bundle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class __SYCL_EXPORT kernel_bundle_plain {
return ext_oneapi_get_kernel(detail::string_view{name});
}

std::string ext_oneapi_get_raw_kernel_name(const std::string &name) {
return std::string{
ext_oneapi_get_raw_kernel_name(detail::string_view{name}).c_str()};
}

protected:
// \returns a kernel object which represents the kernel identified by
// kernel_id passed
Expand Down Expand Up @@ -263,6 +268,7 @@ class __SYCL_EXPORT kernel_bundle_plain {
private:
bool ext_oneapi_has_kernel(detail::string_view name);
kernel ext_oneapi_get_kernel(detail::string_view name);
detail::string ext_oneapi_get_raw_kernel_name(detail::string_view name);
};

} // namespace detail
Expand Down Expand Up @@ -483,6 +489,16 @@ class kernel_bundle : public detail::kernel_bundle_plain,
return detail::kernel_bundle_plain::ext_oneapi_get_kernel(name);
}

/////////////////////////
// ext_oneapi_get_raw_kernel_name
// kernel_bundle must be created from source, throws if not present
/////////////////////////
template <bundle_state _State = State,
typename = std::enable_if_t<_State == bundle_state::executable>>
std::string ext_oneapi_get_raw_kernel_name(const std::string &name) {
return detail::kernel_bundle_plain::ext_oneapi_get_raw_kernel_name(name);
}

private:
kernel_bundle(detail::KernelBundleImplPtr Impl)
: kernel_bundle_plain(std::move(Impl)) {}
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
#define __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS "SYCL/virtual functions"
/// PropertySetRegistry::SYCL_IMPLICIT_LOCAL_ARG defined in PropertySetIO.h
#define __SYCL_PROPERTY_SET_SYCL_IMPLICIT_LOCAL_ARG "SYCL/implicit local arg"
/// PropertySetRegistry::SYCL_REGISTERED_KERNELS defined in PropertySetIO.h
#define __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS "SYCL/registered kernels"

/// Program metadata tags recognized by the PI backends. For kernels the tag
/// must appear after the kernel name.
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/device_binary_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
DeviceRequirements.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
HostPipes.init(Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES);
VirtualFunctions.init(Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);
RegisteredKernels.init(Bin, __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS);

ImageId = ImageCounter++;
}
Expand Down
4 changes: 4 additions & 0 deletions sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ class RTDeviceBinaryImage {
const PropertyRange &getHostPipes() const { return HostPipes; }
const PropertyRange &getVirtualFunctions() const { return VirtualFunctions; }
const PropertyRange &getImplicitLocalArg() const { return ImplicitLocalArg; }
const PropertyRange &getRegisteredKernels() const {
return RegisteredKernels;
}

std::uintptr_t getImageID() const {
assert(Bin && "Image ID is not available without a binary image.");
Expand All @@ -258,6 +261,7 @@ class RTDeviceBinaryImage {
RTDeviceBinaryImage::PropertyRange HostPipes;
RTDeviceBinaryImage::PropertyRange VirtualFunctions;
RTDeviceBinaryImage::PropertyRange ImplicitLocalArg;
RTDeviceBinaryImage::PropertyRange RegisteredKernels;

std::vector<ur_program_metadata_t> ProgramMetadataUR;

Expand Down
17 changes: 2 additions & 15 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1259,29 +1259,16 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
std::pair<sycl_device_binaries, std::string> jit_compiler::compileSYCL(
const std::string &CompilationID, const std::string &SYCLSource,
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr,
const std::vector<std::string> &RegisteredKernelNames) {
const std::vector<std::string> &UserArgs, std::string *LogPtr) {
auto appendToLog = [LogPtr](const char *Msg) {
if (LogPtr) {
LogPtr->append(Msg);
}
};

// RegisteredKernelNames may contain template specializations, so we just put
// them in main() which ensures they are instantiated.
std::ostringstream ss;
ss << SYCLSource << '\n';
ss << "int main() {\n";
for (const std::string &KernelName : RegisteredKernelNames) {
ss << " (void)" << KernelName << ";\n";
}
ss << " return 0;\n}\n" << std::endl;

std::string FinalSource = ss.str();

std::string SYCLFileName = CompilationID + ".cpp";
::jit_compiler::InMemoryFile SourceFile{SYCLFileName.c_str(),
FinalSource.c_str()};
SYCLSource.c_str()};

std::vector<::jit_compiler::InMemoryFile> IncludeFilesView;
IncludeFilesView.reserve(IncludePairs.size());
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/jit_compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class jit_compiler {
std::pair<sycl_device_binaries, std::string> compileSYCL(
const std::string &CompilationID, const std::string &SYCLSource,
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr,
const std::vector<std::string> &RegisteredKernelNames);
const std::vector<std::string> &UserArgs, std::string *LogPtr);

void destroyDeviceBinaries(sycl_device_binaries Binaries);

Expand Down
135 changes: 101 additions & 34 deletions sycl/source/detail/kernel_bundle_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,13 @@ class kernel_bundle_impl {

// oneapi_ext_kernel_compiler
// program manager integration, only for sycl_jit language
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
const std::vector<kernel_id> &KernelIDs,
std::vector<std::string> KNames,
sycl_device_binaries Binaries, std::string Pfx,
syclex::source_language Lang)
kernel_bundle_impl(
context Ctx, std::vector<device> Devs,
const std::vector<kernel_id> &KernelIDs,
std::vector<std::string> &&KernelNames,
std::unordered_map<std::string, std::string> &&MangledKernelNames,
sycl_device_binaries Binaries, std::string &&Prefix,
syclex::source_language Lang)
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
bundle_state::executable) {
assert(Lang == syclex::source_language::sycl_jit);
Expand All @@ -392,9 +394,10 @@ class kernel_bundle_impl {
// loaded via the program manager have `kernel_id`s, they can't be looked up
// from the (unprefixed) kernel name.
MIsInterop = true;
MKernelNames = std::move(KNames);
MKernelNames = std::move(KernelNames);
MMangledKernelNames = std::move(MangledKernelNames);
MDeviceBinaries = Binaries;
MPrefix = std::move(Pfx);
MPrefix = std::move(Prefix);
MLanguage = Lang;
}

Expand Down Expand Up @@ -499,27 +502,70 @@ class kernel_bundle_impl {
if (MLanguage == syclex::source_language::sycl_jit) {
// Build device images via the program manager.
const std::string &SourceStr = std::get<std::string>(MSource);
std::ostringstream SourceExt;
if (!RegisteredKernelNames.empty()) {
SourceExt << SourceStr << '\n';

auto EmitEntry =
[&SourceExt](const std::string &Name) -> std::ostringstream & {
SourceExt << " {\"" << Name << "\", " << Name << "}";
return SourceExt;
};

SourceExt << "[[__sycl_detail__::__registered_kernels__(\n";
for (auto It = RegisteredKernelNames.begin(),
SecondToLast = RegisteredKernelNames.end() - 1;
It != SecondToLast; ++It) {
EmitEntry(*It) << ",\n";
}
EmitEntry(RegisteredKernelNames.back()) << "\n";
SourceExt << ")]];\n";
}

auto [Binaries, Prefix] = syclex::detail::SYCL_JIT_to_SPIRV(
SourceStr, MIncludePairs, BuildOptions, LogPtr,
RegisteredKernelNames);
RegisteredKernelNames.empty() ? SourceStr : SourceExt.str(),
MIncludePairs, BuildOptions, LogPtr);

auto &PM = detail::ProgramManager::getInstance();
PM.addImages(Binaries);

std::vector<kernel_id> KernelIDs;
std::vector<std::string> KernelNames;
std::unordered_map<std::string, std::string> MangledKernelNames;
for (const auto &KernelID : PM.getAllSYCLKernelIDs()) {
std::string_view KernelName{KernelID.get_name()};
if (KernelName.find(Prefix) == 0) {
KernelIDs.push_back(KernelID);
KernelName.remove_prefix(Prefix.length());
KernelNames.emplace_back(KernelName);
static constexpr std::string_view SYCLKernelMarker{"__sycl_kernel_"};
if (KernelName.find(SYCLKernelMarker) == 0) {
// extern "C" declaration, implicitly register kernel without the
// marker.
std::string_view KernelNameWithoutMarker{KernelName};
KernelNameWithoutMarker.remove_prefix(SYCLKernelMarker.length());
MangledKernelNames.emplace(KernelNameWithoutMarker, KernelName);
}
}
}

return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
KernelNames, Binaries, Prefix,
MLanguage);
// Apply frontend information.
for (const auto *RawImg : PM.getRawDeviceImages(KernelIDs)) {
for (const sycl_device_binary_property &RKProp :
RawImg->getRegisteredKernels()) {

auto BA = DeviceBinaryProperty(RKProp).asByteArray();
auto MangledNameLen = BA.consume<uint64_t>() / 8 /*bits in a byte*/;
std::string_view MangledName{
reinterpret_cast<const char *>(BA.begin()), MangledNameLen};
MangledKernelNames.emplace(RKProp->Name, MangledName);
}
}

return std::make_shared<kernel_bundle_impl>(
MContext, MDevices, KernelIDs, std::move(KernelNames),
std::move(MangledKernelNames), Binaries, std::move(Prefix),
MLanguage);
}

ur_program_handle_t UrProgram = nullptr;
Expand Down Expand Up @@ -625,21 +671,27 @@ class kernel_bundle_impl {
KernelNames, MLanguage);
}

std::string adjust_kernel_name(const std::string &Name,
syclex::source_language Lang) {
// Once name demangling support is in, we won't need this.
if (Lang != syclex::source_language::sycl &&
Lang != syclex::source_language::sycl_jit)
return Name;
std::string adjust_kernel_name(const std::string &Name) {
if (MLanguage == syclex::source_language::sycl_jit) {
auto It = MMangledKernelNames.find(Name);
return It == MMangledKernelNames.end() ? Name : It->second;
}

bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
return isMangled ? Name : "__sycl_kernel_" + Name;
if (MLanguage == syclex::source_language::sycl) {
bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
return isMangled ? Name : "__sycl_kernel_" + Name;
}

return Name;
}

bool is_kernel_name(const std::string &Name) {
return std::find(MKernelNames.begin(), MKernelNames.end(), Name) !=
MKernelNames.end();
}

bool ext_oneapi_has_kernel(const std::string &Name) {
auto it = std::find(MKernelNames.begin(), MKernelNames.end(),
adjust_kernel_name(Name, MLanguage));
return it != MKernelNames.end();
return is_kernel_name(adjust_kernel_name(Name));
}

kernel
Expand All @@ -649,13 +701,12 @@ class kernel_bundle_impl {
throw sycl::exception(make_error_code(errc::invalid),
"'ext_oneapi_get_kernel' is only available in "
"kernel_bundles successfully built from "
"kernel_bundle<bundle_state:ext_oneapi_source>.");
"kernel_bundle<bundle_state::ext_oneapi_source>.");

std::string AdjustedName = adjust_kernel_name(Name, MLanguage);
if (!ext_oneapi_has_kernel(Name))
std::string AdjustedName = adjust_kernel_name(Name);
if (!is_kernel_name(AdjustedName))
throw sycl::exception(make_error_code(errc::invalid),
"kernel '" + AdjustedName +
"' not found in kernel_bundle");
"kernel '" + Name + "' not found in kernel_bundle");

if (MLanguage == syclex::source_language::sycl_jit) {
auto &PM = ProgramManager::getInstance();
Expand Down Expand Up @@ -697,6 +748,22 @@ class kernel_bundle_impl {
return detail::createSyclObjFromImpl<kernel>(KernelImpl);
}

std::string ext_oneapi_get_raw_kernel_name(const std::string &Name) {
if (MKernelNames.empty())
throw sycl::exception(
make_error_code(errc::invalid),
"'ext_oneapi_get_raw_kernel_name' is only available in "
"kernel_bundles successfully built from "
"kernel_bundle<bundle_state::ext_oneapi_source>.");

std::string AdjustedName = adjust_kernel_name(Name);
if (!is_kernel_name(AdjustedName))
throw sycl::exception(make_error_code(errc::invalid),
"kernel '" + Name + "' not found in kernel_bundle");

return AdjustedName;
}

bool empty() const noexcept { return MDeviceImages.empty(); }

backend get_backend() const noexcept {
Expand Down Expand Up @@ -872,12 +939,11 @@ class kernel_bundle_impl {
}

bool is_specialization_constant_set(const char *SpecName) const noexcept {
bool SetInDevImg =
std::any_of(begin(), end(),
[SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->is_specialization_constant_set(SpecName);
});
bool SetInDevImg = std::any_of(
begin(), end(), [SpecName](const device_image_plain &DeviceImage) {
return getSyclObjImpl(DeviceImage)
->is_specialization_constant_set(SpecName);
});
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
}

Expand Down Expand Up @@ -968,6 +1034,7 @@ class kernel_bundle_impl {
const std::variant<std::string, std::vector<std::byte>> MSource;
// only kernel_bundles created from source have KernelNames member.
std::vector<std::string> MKernelNames;
std::unordered_map<std::string, std::string> MMangledKernelNames;
sycl_device_binaries MDeviceBinaries = nullptr;
std::string MPrefix;
include_pairs_t MIncludePairs;
Expand Down
14 changes: 6 additions & 8 deletions sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,16 @@ bool SYCL_JIT_Compilation_Available() {
#endif
}

std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
[[maybe_unused]] const std::string &SYCLSource,
[[maybe_unused]] const include_pairs_t &IncludePairs,
[[maybe_unused]] const std::vector<std::string> &UserArgs,
[[maybe_unused]] std::string *LogPtr,
[[maybe_unused]] const std::vector<std::string> &RegisteredKernelNames) {
std::pair<sycl_device_binaries, std::string>
SYCL_JIT_to_SPIRV([[maybe_unused]] const std::string &SYCLSource,
[[maybe_unused]] const include_pairs_t &IncludePairs,
[[maybe_unused]] const std::vector<std::string> &UserArgs,
[[maybe_unused]] std::string *LogPtr) {
#if SYCL_EXT_JIT_ENABLE
static std::atomic_uintptr_t CompilationCounter;
std::string CompilationID = "rtc_" + std::to_string(CompilationCounter++);
return sycl::detail::jit_compiler::get_instance().compileSYCL(
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr,
RegisteredKernelNames);
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr);
#else
throw sycl::exception(sycl::errc::build,
"kernel_compiler via sycl-jit is not available");
Expand Down
8 changes: 3 additions & 5 deletions sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@ std::string userArgsAsString(const std::vector<std::string> &UserArguments);
//
// Returns a pointer to the image (owned by the `jit_compiler` class), and the
// bundle-specific prefix used for loading the kernels.
std::pair<sycl_device_binaries, std::string>
SYCL_JIT_to_SPIRV(const std::string &Source,
const include_pairs_t &IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr,
const std::vector<std::string> &RegisteredKernelNames);
std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
const std::string &Source, const include_pairs_t &IncludePairs,
const std::vector<std::string> &UserArgs, std::string *LogPtr);

void SYCL_JIT_destroy(sycl_device_binaries Binaries);

Expand Down
5 changes: 5 additions & 0 deletions sycl/source/kernel_bundle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ kernel kernel_bundle_plain::ext_oneapi_get_kernel(detail::string_view name) {
return impl->ext_oneapi_get_kernel(name.data(), impl);
}

detail::string
kernel_bundle_plain::ext_oneapi_get_raw_kernel_name(detail::string_view name) {
return detail::string{impl->ext_oneapi_get_raw_kernel_name(name.data())};
}

//////////////////////////////////
///// sycl::detail free functions
//////////////////////////////////
Expand Down
4 changes: 4 additions & 0 deletions sycl/test-e2e/KernelCompiler/kernel_compiler_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ void test_build_and_run() {
assert(hasHerKernel && "her_kernel should exist, but doesn't");
assert(!notExistKernel && "non-existing kernel should NOT exist, but does?");

assert(
kbExe2.ext_oneapi_get_raw_kernel_name("my_kernel") == "my_kernel" &&
"source code name and compiler-generated name should match, but don't");

sycl::kernel my_kernel = kbExe2.ext_oneapi_get_kernel("my_kernel");
sycl::kernel her_kernel = kbExe2.ext_oneapi_get_kernel("her_kernel");

Expand Down
Loading