Skip to content

Commit 6df9b9a

Browse files
author
JackAKirk
committed
Example implementation using getDeviceImpl.
Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
1 parent d3af00e commit 6df9b9a

File tree

6 files changed

+49
-17
lines changed

6 files changed

+49
-17
lines changed

sycl/include/CL/sycl/backend.hpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,6 @@ typename std::enable_if<
260260
detail::InteropFeatureSupportMap<Backend>::MakeDevice == true, device>::type
261261
make_device(const typename backend_traits<Backend>::template input_type<device>
262262
&BackendObject) {
263-
auto plts = platform::get_platforms();
264-
for (const auto &plt : plts) {
265-
if (plt.get_backend() == Backend) {
266-
auto devs = plt.get_devices(info::device_type::all);
267-
for (auto &dev : devs) {
268-
if (BackendObject == get_native<Backend>(dev)) {
269-
return dev;
270-
}
271-
}
272-
}
273-
}
274263
return detail::make_device(detail::pi::cast<pi_native_handle>(BackendObject),
275264
Backend);
276265
}

sycl/include/CL/sycl/platform.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,18 @@ class device_selector;
2626
class device;
2727
namespace detail {
2828
class platform_impl;
29-
}
29+
device make_device(pi_native_handle NativeHandle, backend Backend);
30+
} // namespace detail
3031

3132
/// Encapsulates a SYCL platform on which kernels may be executed.
3233
///
3334
/// \ingroup sycl_api
3435
class __SYCL_EXPORT platform {
3536
public:
37+
38+
friend device detail::make_device(pi_native_handle NativeHandle,
39+
backend Backend);
40+
3641
/// Constructs a SYCL platform as a host platform.
3742
platform();
3843

sycl/include/sycl/ext/oneapi/experimental/backend/cuda.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ interop_handle::get_native_context<backend::ext_oneapi_cuda>() const {
7575
template <>
7676
inline device make_device<backend::ext_oneapi_cuda>(
7777
const backend_input_t<backend::ext_oneapi_cuda, device> &BackendObject) {
78+
auto plts = platform::get_platforms();
79+
for (const auto &plt : plts) {
80+
if (plt.get_backend() == backend::ext_oneapi_cuda) {
81+
auto devs = plt.get_devices(info::device_type::all);
82+
for (auto &dev : devs) {
83+
if (BackendObject == get_native<backend::ext_oneapi_cuda>(dev)) {
84+
return dev;
85+
}
86+
}
87+
}
88+
}
7889
pi_native_handle NativeHandle = static_cast<pi_native_handle>(BackendObject);
7990
return ext::oneapi::cuda::make_device(NativeHandle);
8091
}

sycl/source/backend.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,21 @@ platform make_platform(pi_native_handle NativeHandle, backend Backend) {
5757

5858
__SYCL_EXPORT device make_device(pi_native_handle NativeHandle,
5959
backend Backend) {
60-
const auto &Plugin = getPlugin(Backend);
6160

62-
pi::PiDevice PiDevice = nullptr;
63-
Plugin.call<PiApiKind::piextDeviceCreateWithNativeHandle>(NativeHandle,
64-
nullptr, &PiDevice);
65-
// Construct the SYCL device from PI device.
61+
auto plts = platform::get_platforms();
62+
detail::pi::PiDevice PiDevice = nullptr;
63+
const auto &Plugin = detail::getPlugin(Backend);
64+
for (const auto &plt : plts) {
65+
if (plt.get_backend() == Backend) {
66+
Plugin.call<detail::PiApiKind::piextDeviceCreateWithNativeHandle>(
67+
NativeHandle, nullptr, &PiDevice);
68+
auto devImpl = plt.impl->getDeviceImpl(PiDevice, plt.impl);
69+
if (devImpl != nullptr) {
70+
return detail::createSyclObjFromImpl<device>(devImpl);
71+
}
72+
}
73+
}
74+
6675
return detail::createSyclObjFromImpl<device>(
6776
std::make_shared<device_impl>(PiDevice, Plugin));
6877
}

sycl/source/detail/platform_impl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,20 @@ static void filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
212212
Plugin.setLastDeviceId(Platform, DeviceNum);
213213
}
214214

215+
std::shared_ptr<device_impl> platform_impl::getDeviceImpl(
216+
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
217+
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);
218+
219+
// If we've already seen this device, return the impl
220+
for (const std::weak_ptr<device_impl> &DeviceWP : MDeviceCache) {
221+
if (std::shared_ptr<device_impl> Device = DeviceWP.lock()) {
222+
if (Device->getHandleRef() == PiDevice)
223+
return Device;
224+
}
225+
}
226+
return nullptr;
227+
}
228+
215229
std::shared_ptr<device_impl> platform_impl::getOrMakeDeviceImpl(
216230
RT::PiDevice PiDevice, const std::shared_ptr<platform_impl> &PlatformImpl) {
217231
const std::lock_guard<std::mutex> Guard(MDeviceMapMutex);

sycl/source/detail/platform_impl.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ class platform_impl {
152152
getOrMakeDeviceImpl(RT::PiDevice PiDevice,
153153
const std::shared_ptr<platform_impl> &PlatformImpl);
154154

155+
std::shared_ptr<device_impl>
156+
getDeviceImpl(RT::PiDevice PiDevice,
157+
const std::shared_ptr<platform_impl> &PlatformImpl);
158+
155159
/// Static functions that help maintain platform uniquess and
156160
/// equality of comparison
157161

0 commit comments

Comments
 (0)