Skip to content

[NFC][SYCL] Ensure context_impl is always created via std::make_shared #18795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
NativeHandle, Adapter->getUrAdapter(), DeviceHandles.size(),
DeviceHandles.data(), &Properties, &UrContext);
// Construct the SYCL context from UR context.
return detail::createSyclObjFromImpl<context>(std::make_shared<context_impl>(
return detail::createSyclObjFromImpl<context>(context_impl::create(
UrContext, Handler, Adapter, DeviceList, !KeepOwnership));
}

Expand Down
6 changes: 2 additions & 4 deletions sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ context::context(const std::vector<device> &DeviceList,
throw exception(make_error_code(errc::invalid),
"Can't add devices across platforms to a single context.");
else
impl = std::make_shared<detail::context_impl>(DeviceList, AsyncHandler,
PropList);
impl = detail::context_impl::create(DeviceList, AsyncHandler, PropList);
}
context::context(cl_context ClContext, async_handler AsyncHandler) {
const auto &Adapter = sycl::detail::ur::getAdapter<backend::opencl>();
Expand All @@ -81,8 +80,7 @@ context::context(cl_context ClContext, async_handler AsyncHandler) {
Adapter->call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext);

impl =
std::make_shared<detail::context_impl>(hContext, AsyncHandler, Adapter);
impl = detail::context_impl::create(hContext, AsyncHandler, Adapter);
}

template <typename Param>
Expand Down
14 changes: 2 additions & 12 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,9 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

context_impl::context_impl(const device &Device, async_handler AsyncHandler,
const property_list &PropList)
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(1, Device),
MContext(nullptr),
MPlatform(detail::getSyclObjImpl(Device.get_platform())),
MPropList(PropList), MSupportBufferLocationByDevices(NotChecked) {
verifyProps(PropList);
MKernelProgramCache.setContextPtr(this);
}

context_impl::context_impl(const std::vector<sycl::device> Devices,
async_handler AsyncHandler,
const property_list &PropList)
const property_list &PropList, private_tag)
: MOwnedByRuntime(true), MAsyncHandler(AsyncHandler), MDevices(Devices),
MContext(nullptr),
MPlatform(detail::getSyclObjImpl(MDevices[0].get_platform())),
Expand Down Expand Up @@ -72,7 +62,7 @@ context_impl::context_impl(ur_context_handle_t UrContext,
async_handler AsyncHandler,
const AdapterPtr &Adapter,
const std::vector<sycl::device> &DeviceList,
bool OwnedByRuntime)
bool OwnedByRuntime, private_tag)
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler),
MDevices(DeviceList), MContext(UrContext), MPlatform(),
MSupportBufferLocationByDevices(NotChecked) {
Expand Down
40 changes: 24 additions & 16 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,12 @@ inline namespace _V1 {
// Forward declaration
class device;
namespace detail {
class context_impl {
public:
/// Constructs a context_impl using a single SYCL devices.
///
/// The constructed context_impl will use the AsyncHandler parameter to
/// handle exceptions.
/// PropList carries the properties of the constructed context_impl.
///
/// \param Device is an instance of SYCL device.
/// \param AsyncHandler is an instance of async_handler.
/// \param PropList is an instance of property_list.
context_impl(const device &Device, async_handler AsyncHandler,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was unused before this PR.

const property_list &PropList);
class context_impl : std::enable_shared_from_this<context_impl> {
struct private_tag {
explicit private_tag() = default;
};

public:
/// Constructs a context_impl using a list of SYCL devices.
///
/// Newly created instance will save each SYCL device in the list. This
Expand All @@ -56,7 +48,8 @@ class context_impl {
/// \param AsyncHandler is an instance of async_handler.
/// \param PropList is an instance of property_list.
context_impl(const std::vector<sycl::device> DeviceList,
async_handler AsyncHandler, const property_list &PropList);
async_handler AsyncHandler, const property_list &PropList,
private_tag);

/// Construct a context_impl using plug-in interoperability handle.
///
Expand All @@ -70,8 +63,23 @@ class context_impl {
/// transferred to runtime
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const AdapterPtr &Adapter,
const std::vector<sycl::device> &DeviceList = {},
bool OwnedByRuntime = true);
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
private_tag);

context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const AdapterPtr &Adapter, private_tag tag)
: context_impl(UrContext, AsyncHandler, Adapter,
std::vector<sycl::device>{},
/*OwnedByRuntime*/ true, tag) {}

// Single variadic method works because all the ctors are expected to be
// "public" except the `private_tag` part restricting the creation to
// `std::shared_ptr` allocations.
template <typename... Ts>
static std::shared_ptr<context_impl> create(Ts &&...args) {
return std::make_shared<context_impl>(std::forward<Ts>(args)...,
private_tag{});
}

~context_impl();

Expand Down