Skip to content

[SYCL][NFC] Pass adapter by reference in backend[_impl] #19187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions sycl/source/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

static const AdapterPtr &getAdapter(backend Backend) {
static const adapter_impl &getAdapter(backend Backend) {
switch (Backend) {
case backend::opencl:
return ur::getAdapter<backend::opencl>();
return *ur::getAdapter<backend::opencl>();
case backend::ext_oneapi_level_zero:
return ur::getAdapter<backend::ext_oneapi_level_zero>();
return *ur::getAdapter<backend::ext_oneapi_level_zero>();
case backend::ext_oneapi_cuda:
return ur::getAdapter<backend::ext_oneapi_cuda>();
return *ur::getAdapter<backend::ext_oneapi_cuda>();
case backend::ext_oneapi_hip:
return ur::getAdapter<backend::ext_oneapi_hip>();
return *ur::getAdapter<backend::ext_oneapi_hip>();
default:
throw sycl::exception(
sycl::make_error_code(sycl::errc::runtime),
Expand Down Expand Up @@ -71,24 +71,24 @@ backend convertUrBackend(ur_backend_t UrBackend) {
}

platform make_platform(ur_native_handle_t NativeHandle, backend Backend) {
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);

// Create UR platform first.
ur_platform_handle_t UrPlatform = nullptr;
Adapter->call<UrApiKind::urPlatformCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrPlatform);
Adapter.call<UrApiKind::urPlatformCreateWithNativeHandle>(
NativeHandle, Adapter.getUrAdapter(), nullptr, &UrPlatform);

return detail::createSyclObjFromImpl<platform>(
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter));
}

__SYCL_EXPORT device make_device(ur_native_handle_t NativeHandle,
backend Backend) {
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);

ur_device_handle_t UrDevice = nullptr;
Adapter->call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), nullptr, &UrDevice);
Adapter.call<UrApiKind::urDeviceCreateWithNativeHandle>(
NativeHandle, Adapter.getUrAdapter(), nullptr, &UrDevice);

// Construct the SYCL device from UR device.
return detail::createSyclObjFromImpl<device>(
Expand All @@ -100,7 +100,7 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
const async_handler &Handler,
backend Backend, bool KeepOwnership,
const std::vector<device> &DeviceList) {
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);

ur_context_handle_t UrContext = nullptr;
ur_context_native_properties_t Properties{};
Expand All @@ -110,8 +110,8 @@ __SYCL_EXPORT context make_context(ur_native_handle_t NativeHandle,
for (const auto &Dev : DeviceList) {
DeviceHandles.push_back(detail::getSyclObjImpl(Dev)->getHandleRef());
}
Adapter->call<UrApiKind::urContextCreateWithNativeHandle>(
NativeHandle, Adapter->getUrAdapter(), DeviceHandles.size(),
Adapter.call<UrApiKind::urContextCreateWithNativeHandle>(
NativeHandle, Adapter.getUrAdapter(), DeviceHandles.size(),
DeviceHandles.data(), &Properties, &UrContext);
// Construct the SYCL context from UR context.
return detail::createSyclObjFromImpl<context>(context_impl::create(
Expand All @@ -125,7 +125,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
const async_handler &Handler, backend Backend) {
ur_device_handle_t UrDevice =
Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr;
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);
context_impl &ContextImpl = *getSyclObjImpl(Context);

if (PropList.has_property<ext::intel::property::queue::compute_index>()) {
Expand Down Expand Up @@ -155,7 +155,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
// Create UR queue first.
ur_queue_handle_t UrQueue = nullptr;

Adapter->call<UrApiKind::urQueueCreateWithNativeHandle>(
Adapter.call<UrApiKind::urQueueCreateWithNativeHandle>(
NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties,
&UrQueue);
// Construct the SYCL queue from UR queue.
Expand All @@ -171,15 +171,15 @@ __SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle,
__SYCL_EXPORT event make_event(ur_native_handle_t NativeHandle,
const context &Context, bool KeepOwnership,
backend Backend) {
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(Context);

ur_event_handle_t UrEvent = nullptr;
ur_event_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Adapter->call<UrApiKind::urEventCreateWithNativeHandle>(
Adapter.call<UrApiKind::urEventCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrEvent);
event Event = detail::createSyclObjFromImpl<event>(
event_impl::create_from_handle(UrEvent, Context));
Expand All @@ -193,15 +193,15 @@ std::shared_ptr<detail::kernel_bundle_impl>
make_kernel_bundle(ur_native_handle_t NativeHandle,
const context &TargetContext, bool KeepOwnership,
bundle_state State, backend Backend) {
const auto &Adapter = getAdapter(Backend);
const adapter_impl &Adapter = getAdapter(Backend);
const auto &ContextImpl = getSyclObjImpl(TargetContext);

ur_program_handle_t UrProgram = nullptr;
ur_program_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;

Adapter->call<UrApiKind::urProgramCreateWithNativeHandle>(
Adapter.call<UrApiKind::urProgramCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), &Properties, &UrProgram);
if (UrProgram == nullptr)
throw sycl::exception(
Expand All @@ -214,39 +214,39 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
std::vector<ur_device_handle_t> ProgramDevices;
uint32_t NumDevices = 0;

Adapter->call<UrApiKind::urProgramGetInfo>(
Adapter.call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(NumDevices), &NumDevices,
nullptr);
ProgramDevices.resize(NumDevices);
Adapter->call<UrApiKind::urProgramGetInfo>(
Adapter.call<UrApiKind::urProgramGetInfo>(
UrProgram, UR_PROGRAM_INFO_DEVICES,
sizeof(ur_device_handle_t) * NumDevices, ProgramDevices.data(), nullptr);

for (auto &Dev : ProgramDevices) {
ur_program_binary_type_t BinaryType;
Adapter->call<UrApiKind::urProgramGetBuildInfo>(
Adapter.call<UrApiKind::urProgramGetBuildInfo>(
UrProgram, Dev, UR_PROGRAM_BUILD_INFO_BINARY_TYPE,
sizeof(ur_program_binary_type_t), &BinaryType, nullptr);
switch (BinaryType) {
case (UR_PROGRAM_BINARY_TYPE_NONE):
if (State == bundle_state::object) {
auto Res = Adapter->call_nocheck<UrApiKind::urProgramCompileExp>(
auto Res = Adapter.call_nocheck<UrApiKind::urProgramCompileExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramCompile>(
Res = Adapter.call_nocheck<UrApiKind::urProgramCompile>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Adapter->checkUrResult<errc::build>(Res);
Adapter.checkUrResult<errc::build>(Res);
}

else if (State == bundle_state::executable) {
auto Res = Adapter->call_nocheck<UrApiKind::urProgramBuildExp>(
auto Res = Adapter.call_nocheck<UrApiKind::urProgramBuildExp>(
UrProgram, 1, &Dev, nullptr);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramBuild>(
Res = Adapter.call_nocheck<UrApiKind::urProgramBuild>(
ContextImpl->getHandleRef(), UrProgram, nullptr);
}
Adapter->checkUrResult<errc::build>(Res);
Adapter.checkUrResult<errc::build>(Res);
}

break;
Expand All @@ -259,15 +259,15 @@ make_kernel_bundle(ur_native_handle_t NativeHandle,
detail::codeToString(UR_RESULT_ERROR_INVALID_VALUE));
if (State == bundle_state::executable) {
ur_program_handle_t UrLinkedProgram = nullptr;
auto Res = Adapter->call_nocheck<UrApiKind::urProgramLinkExp>(
auto Res = Adapter.call_nocheck<UrApiKind::urProgramLinkExp>(
ContextImpl->getHandleRef(), 1, &Dev, 1, &UrProgram, nullptr,
&UrLinkedProgram);
if (Res == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
Res = Adapter->call_nocheck<UrApiKind::urProgramLink>(
Res = Adapter.call_nocheck<UrApiKind::urProgramLink>(
ContextImpl->getHandleRef(), 1, &UrProgram, nullptr,
&UrLinkedProgram);
}
Adapter->checkUrResult<errc::build>(Res);
Adapter.checkUrResult<errc::build>(Res);
if (UrLinkedProgram != nullptr) {
UrProgram = UrLinkedProgram;
}
Expand Down Expand Up @@ -351,7 +351,7 @@ kernel make_kernel(const context &TargetContext,
ur_kernel_native_properties_t Properties{};
Properties.stype = UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES;
Properties.isNativeHandleOwned = !KeepOwnership;
Adapter->call<UrApiKind::urKernelCreateWithNativeHandle>(
Adapter.call<UrApiKind::urKernelCreateWithNativeHandle>(
NativeHandle, ContextImpl->getHandleRef(), UrProgram, &Properties,
&UrKernel);

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ context::context(cl_context ClContext, async_handler AsyncHandler) {
Adapter->call<detail::UrApiKind::urContextCreateWithNativeHandle>(
nativeHandle, Adapter->getUrAdapter(), 0, nullptr, nullptr, &hContext);

impl = detail::context_impl::create(hContext, AsyncHandler, Adapter);
impl = detail::context_impl::create(hContext, AsyncHandler, *Adapter);
}

template <typename Param>
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/adapter_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class adapter_impl {
return UrPlatforms;
}

ur_adapter_handle_t getUrAdapter() { return MAdapter; }
ur_adapter_handle_t getUrAdapter() const { return MAdapter; }

/// Calls the UR Api, traces the call, and returns the result.
///
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
// Get platform's backend and put it to DeviceDesc
DeviceDescT DeviceDesc;
platform_impl &PlatformImpl =
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter);
backend Backend = PlatformImpl.getBackend();

for (const auto &SyclBe : getSyclBeMap()) {
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,

context_impl::context_impl(ur_context_handle_t UrContext,
async_handler AsyncHandler,
const AdapterPtr &Adapter,
const adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList,
bool OwnedByRuntime, private_tag)
: MOwnedByRuntime(OwnedByRuntime), MAsyncHandler(AsyncHandler),
Expand All @@ -74,12 +74,12 @@ context_impl::context_impl(ur_context_handle_t UrContext,
std::vector<ur_device_handle_t> DeviceIds;
uint32_t DevicesNum = 0;
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urContextGetInfo>(
Adapter.call<UrApiKind::urContextGetInfo>(
MContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(DevicesNum), &DevicesNum,
nullptr);
DeviceIds.resize(DevicesNum);
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urContextGetInfo>(
Adapter.call<UrApiKind::urContextGetInfo>(
MContext, UR_CONTEXT_INFO_DEVICES,
sizeof(ur_device_handle_t) * DevicesNum, &DeviceIds[0], nullptr);

Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
/// \param OwnedByRuntime is the flag if ownership is kept by user or
/// transferred to runtime
context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const AdapterPtr &Adapter,
const adapter_impl &Adapter,
const std::vector<sycl::device> &DeviceList, bool OwnedByRuntime,
private_tag);

context_impl(ur_context_handle_t UrContext, async_handler AsyncHandler,
const AdapterPtr &Adapter, private_tag tag)
const adapter_impl &Adapter, private_tag tag)
: context_impl(UrContext, AsyncHandler, Adapter,
std::vector<sycl::device>{},
/*OwnedByRuntime*/ true, tag) {}
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/device_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
CASE(info::device::platform) {
return createSyclObjFromImpl<platform>(
platform_impl::getOrMakePlatformImpl(
get_info_impl<UR_DEVICE_INFO_PLATFORM>(), getAdapter()));
get_info_impl<UR_DEVICE_INFO_PLATFORM>(), *getAdapter()));
}

CASE(info::device::profile) {
Expand Down
16 changes: 8 additions & 8 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace detail {

platform_impl &
platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
const AdapterPtr &Adapter) {
const adapter_impl &Adapter) {
std::shared_ptr<platform_impl> Result;
{
const std::lock_guard<std::mutex> Guard(
Expand All @@ -50,8 +50,8 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
// Otherwise make the impl. Our ctor/dtor are private, so std::make_shared
// needs a bit of help...
struct creator : platform_impl {
creator(ur_platform_handle_t APlatform, const AdapterPtr &AAdapter)
: platform_impl(APlatform, AAdapter) {}
creator(ur_platform_handle_t APlatform, const adapter_impl &AAdapter)
: platform_impl(APlatform, &AAdapter) {}
};
Result = std::make_shared<creator>(UrPlatform, Adapter);
PlatformCache.emplace_back(Result);
Expand All @@ -62,12 +62,12 @@ platform_impl::getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,

platform_impl &
platform_impl::getPlatformFromUrDevice(ur_device_handle_t UrDevice,
const AdapterPtr &Adapter) {
const adapter_impl &Adapter) {
ur_platform_handle_t Plt =
nullptr; // TODO catch an exception and put it to list
// of asynchronous exceptions
Adapter->call<UrApiKind::urDeviceGetInfo>(UrDevice, UR_DEVICE_INFO_PLATFORM,
sizeof(Plt), &Plt, nullptr);
Adapter.call<UrApiKind::urDeviceGetInfo>(UrDevice, UR_DEVICE_INFO_PLATFORM,
sizeof(Plt), &Plt, nullptr);
return getOrMakePlatformImpl(Plt, Adapter);
}

Expand Down Expand Up @@ -131,7 +131,7 @@ std::vector<platform> platform_impl::getAdapterPlatforms(AdapterPtr &Adapter,

for (const auto &UrPlatform : UrPlatforms) {
platform Platform = detail::createSyclObjFromImpl<platform>(
getOrMakePlatformImpl(UrPlatform, Adapter));
getOrMakePlatformImpl(UrPlatform, *Adapter));
const bool IsBanned = IsBannedPlatform(Platform);
bool HasAnyDevices = false;

Expand Down Expand Up @@ -543,7 +543,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {

// The next step is to inflate the filtered UrDevices into SYCL Device
// objects.
platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, MAdapter);
platform_impl &PlatformImpl = getOrMakePlatformImpl(MPlatform, *MAdapter);
std::transform(UrDevices.begin(), UrDevices.end(), std::back_inserter(Res),
[&PlatformImpl](const ur_device_handle_t UrDevice) -> device {
return detail::createSyclObjFromImpl<device>(
Expand Down
21 changes: 8 additions & 13 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
//
// Platforms can only be created under `GlobalHandler`'s ownership via
// `platform_impl::getOrMakePlatformImpl` method.
explicit platform_impl(ur_platform_handle_t APlatform, adapter_impl *AAdapter)
: MPlatform(APlatform), MAdapter(AAdapter) {
explicit platform_impl(ur_platform_handle_t APlatform,
const adapter_impl *AAdapter)
: MPlatform(APlatform) {

MAdapter = const_cast<AdapterPtr>(AAdapter);

// Find out backend of the platform
ur_backend_t UrBackend = UR_BACKEND_UNKNOWN;
AAdapter->call_nocheck<UrApiKind::urPlatformGetInfo>(
Expand Down Expand Up @@ -137,15 +141,6 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
// \return the Adapter associated with this platform.
const AdapterPtr &getAdapter() const { return MAdapter; }

/// Sets the platform implementation to use another adapter.
///
/// \param AdapterPtr is a pointer to a adapter instance
/// \param Backend is the backend that we want this platform to use
void setAdapter(AdapterPtr &AdapterPtr, backend Backend) {
MAdapter = AdapterPtr;
MBackend = Backend;
}

/// Gets the native handle of the SYCL platform.
///
/// \return a native handle.
Expand Down Expand Up @@ -188,7 +183,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
/// \param Adapter is the UR adapter providing the backend for the platform
/// \return the platform_impl representing the UR platform
static platform_impl &getOrMakePlatformImpl(ur_platform_handle_t UrPlatform,
const AdapterPtr &Adapter);
const adapter_impl &Adapter);

/// Queries the cache for the specified platform based on an input device.
/// If found, returns the the cached platform_impl, otherwise creates a new
Expand All @@ -200,7 +195,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
/// platform
/// \return the platform_impl that contains the input device
static platform_impl &getPlatformFromUrDevice(ur_device_handle_t UrDevice,
const AdapterPtr &Adapter);
const adapter_impl &Adapter);

context_impl &khr_get_default_context();

Expand Down
2 changes: 1 addition & 1 deletion sycl/source/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ device::device(cl_device_id DeviceId) {
Adapter->call<detail::UrApiKind::urDeviceCreateWithNativeHandle>(
detail::ur::cast<ur_native_handle_t>(DeviceId), Adapter->getUrAdapter(),
nullptr, &Device);
impl = detail::platform_impl::getPlatformFromUrDevice(Device, Adapter)
impl = detail::platform_impl::getPlatformFromUrDevice(Device, *Adapter)
.getOrMakeDeviceImpl(Device)
.shared_from_this();
__SYCL_OCL_CALL(clRetainDevice, DeviceId);
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ platform::platform(cl_platform_id PlatformId) {
Adapter->call<detail::UrApiKind::urPlatformCreateWithNativeHandle>(
detail::ur::cast<ur_native_handle_t>(PlatformId), Adapter->getUrAdapter(),
/* pProperties = */ nullptr, &UrPlatform);
impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter)
impl = detail::platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter)
.shared_from_this();
}

Expand Down
Loading