Skip to content

[NFC][SYCL] Pass context_impl by raw ptr/ref in device_image_impl.hpp #18981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions sycl/source/detail/device_image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ constexpr uint8_t ImageOriginKernelCompiler = 1 << 2;
class ManagedDeviceGlobalsRegistry {
public:
ManagedDeviceGlobalsRegistry(
const std::shared_ptr<context_impl> &ContextImpl,
const std::string &Prefix, std::vector<std::string> &&DeviceGlobalNames,
context_impl &ContextImpl, const std::string &Prefix,
std::vector<std::string> &&DeviceGlobalNames,
std::vector<std::unique_ptr<std::byte[]>> &&DeviceGlobalAllocations)
: MContextImpl{ContextImpl}, MPrefix{Prefix},
: MContextImpl{ContextImpl.shared_from_this()}, MPrefix{Prefix},
MDeviceGlobalNames{std::move(DeviceGlobalNames)},
MDeviceGlobalAllocations{std::move(DeviceGlobalAllocations)} {}

Expand Down Expand Up @@ -704,12 +704,11 @@ class device_image_impl {
assert(MRTCBinInfo);
assert(MOrigins & ImageOriginKernelCompiler);

const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);

for (const auto &SyclDev : Devices) {
device_impl &DevImpl = *getSyclObjImpl(SyclDev);
if (!ContextImpl->hasDevice(DevImpl)) {
if (!ContextImpl.hasDevice(DevImpl)) {
throw sycl::exception(make_error_code(errc::invalid),
"device not part of kernel_bundle context");
}
Expand Down Expand Up @@ -742,7 +741,7 @@ class device_image_impl {
Devices, BuildOptions, *SourceStrPtr, UrProgram);
}

const AdapterPtr &Adapter = ContextImpl->getAdapter();
const AdapterPtr &Adapter = ContextImpl.getAdapter();

if (!FetchedFromCache)
UrProgram = createProgramFromSource(Devices, BuildOptions, LogPtr);
Expand All @@ -752,7 +751,7 @@ class device_image_impl {
UrProgram, DeviceVec.size(), DeviceVec.data(), XsFlags.c_str());
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramBuild>(
ContextImpl->getHandleRef(), UrProgram, XsFlags.c_str());
ContextImpl.getHandleRef(), UrProgram, XsFlags.c_str());
}
Adapter->checkUrResult<errc::build>(Res);

Expand Down Expand Up @@ -796,12 +795,11 @@ class device_image_impl {
"compile is only available for kernel_bundle<bundle_state::source> "
"when the source language was sycl.");

std::shared_ptr<sycl::detail::context_impl> ContextImpl =
getSyclObjImpl(MContext);
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);

for (const auto &SyclDev : Devices) {
detail::device_impl &DevImpl = *getSyclObjImpl(SyclDev);
if (!ContextImpl->hasDevice(DevImpl)) {
if (!ContextImpl.hasDevice(DevImpl)) {
throw sycl::exception(make_error_code(errc::invalid),
"device not part of kernel_bundle context");
}
Expand Down Expand Up @@ -873,9 +871,8 @@ class device_image_impl {
const std::vector<device> Devices,
const std::vector<sycl::detail::string_view> &BuildOptions,
const std::string &SourceStr, ur_program_handle_t &UrProgram) const {
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();

std::string UserArgs = syclex::detail::userArgsAsString(BuildOptions);

Expand Down Expand Up @@ -904,7 +901,7 @@ class device_image_impl {
Properties.pMetadatas = nullptr;

Adapter->call<UrApiKind::urProgramCreateWithBinary>(
ContextImpl->getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
ContextImpl.getHandleRef(), DeviceHandles.size(), DeviceHandles.data(),
Lengths.data(), Binaries.data(), &Properties, &UrProgram);

return true;
Expand Down Expand Up @@ -1132,7 +1129,7 @@ class device_image_impl {
}

auto DGRegs = std::make_shared<ManagedDeviceGlobalsRegistry>(
getSyclObjImpl(MContext), std::string{Prefix},
*getSyclObjImpl(MContext), std::string{Prefix},
std::move(DeviceGlobalNames), std::move(DeviceGlobalAllocations));

// Mark the image as input so the program manager will bring it into
Expand Down Expand Up @@ -1195,9 +1192,8 @@ class device_image_impl {
createProgramFromSource(const std::vector<device> Devices,
const std::vector<sycl::detail::string_view> &Options,
std::string *LogPtr) const {
const std::shared_ptr<sycl::detail::context_impl> &ContextImpl =
getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl->getAdapter();
sycl::detail::context_impl &ContextImpl = *getSyclObjImpl(MContext);
const AdapterPtr &Adapter = ContextImpl.getAdapter();
const auto spirv = [&]() -> std::vector<uint8_t> {
switch (MRTCBinInfo->MLanguage) {
case syclex::source_language::opencl: {
Expand Down Expand Up @@ -1234,7 +1230,7 @@ class device_image_impl {
}();

ur_program_handle_t UrProgram = nullptr;
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl->getHandleRef(),
Adapter->call<UrApiKind::urProgramCreateWithIL>(ContextImpl.getHandleRef(),
spirv.data(), spirv.size(),
nullptr, &UrProgram);
// program created by urProgramCreateWithIL is implicitly retained.
Expand Down