Skip to content

Commit 3fce815

Browse files
committed
Add get_devices method to dpctl.SyclPlatform
1 parent 180daa5 commit 3fce815

File tree

5 files changed

+151
-2
lines changed

5 files changed

+151
-2
lines changed

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
313313
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
314314
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
315315
const DPCTLSyclPlatformRef)
316+
cdef DPCTLDeviceVectorRef DPCTLPlatform_GetDevices(
317+
const DPCTLSyclPlatformRef PRef, _device_type DTy)
316318

317319

318320
cdef extern from "syclinterface/dpctl_sycl_context_interface.h":

dpctl/_sycl_platform.pyx

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ from libcpp cimport bool
2626
from ._backend cimport ( # noqa: E211
2727
DPCTLCString_Delete,
2828
DPCTLDeviceSelector_Delete,
29+
DPCTLDeviceVector_Delete,
30+
DPCTLDeviceVector_GetAt,
31+
DPCTLDeviceVector_Size,
32+
DPCTLDeviceVectorRef,
2933
DPCTLFilterSelector_Create,
3034
DPCTLPlatform_AreEq,
3135
DPCTLPlatform_Copy,
@@ -34,6 +38,7 @@ from ._backend cimport ( # noqa: E211
3438
DPCTLPlatform_Delete,
3539
DPCTLPlatform_GetBackend,
3640
DPCTLPlatform_GetDefaultContext,
41+
DPCTLPlatform_GetDevices,
3742
DPCTLPlatform_GetName,
3843
DPCTLPlatform_GetPlatforms,
3944
DPCTLPlatform_GetVendor,
@@ -46,17 +51,21 @@ from ._backend cimport ( # noqa: E211
4651
DPCTLPlatformVector_Size,
4752
DPCTLPlatformVectorRef,
4853
DPCTLSyclContextRef,
54+
DPCTLSyclDeviceRef,
4955
DPCTLSyclDeviceSelectorRef,
5056
DPCTLSyclPlatformRef,
5157
_backend_type,
58+
_device_type,
5259
)
5360

5461
import warnings
5562

5663
from ._sycl_context import SyclContextCreationError
5764
from .enum_types import backend_type
65+
from .enum_types import device_type as device_type_t
5866

5967
from ._sycl_context cimport SyclContext
68+
from ._sycl_device cimport SyclDevice
6069

6170
__all__ = [
6271
"get_platforms",
@@ -366,6 +375,92 @@ cdef class SyclPlatform(_SyclPlatform):
366375
"""
367376
return DPCTLPlatform_Hash(self._platform_ref)
368377

378+
def get_devices(self, device_type=device_type_t.all):
379+
"""
380+
Returns the list of :class:`dpctl.SyclDevice` objects associated with
381+
:class:`dpctl.SyclPlatform` instance selected based on
382+
the given :class:`dpctl.device_type`.
383+
384+
Args:
385+
device_type (optional):
386+
A :class:`dpctl.device_type` enum value or a string that
387+
specifies a SYCL device type. Currently, accepted values are:
388+
"gpu", "cpu", "accelerator", "host", or "all".
389+
Default: ``dpctl.device_type.all``.
390+
391+
Returns:
392+
list:
393+
A :obj:`list` of :class:`dpctl.SyclDevice` objects
394+
that belong to this platform.
395+
396+
Raises:
397+
TypeError:
398+
If `device_type` is not a str or :class:`dpctl.device_type`
399+
enum.
400+
ValueError:
401+
If the value of `device_type` is not supported.
402+
403+
If the ``DPCTLPlatform_GetDevices`` call returned
404+
``NULL`` instead of a ``DPCTLDeviceVectorRef`` object.
405+
"""
406+
cdef _device_type DTy = _device_type._ALL_DEVICES
407+
cdef DPCTLDeviceVectorRef DVRef = NULL
408+
cdef size_t num_devs
409+
cdef size_t i
410+
cdef DPCTLSyclDeviceRef DRef
411+
412+
if isinstance(device_type, str):
413+
dty_str = device_type.strip().lower()
414+
if dty_str == "accelerator":
415+
DTy = _device_type._ACCELERATOR
416+
elif dty_str == "all":
417+
DTy = _device_type._ALL_DEVICES
418+
elif dty_str == "automatic":
419+
DTy = _device_type._AUTOMATIC
420+
elif dty_str == "cpu":
421+
DTy = _device_type._CPU
422+
elif dty_str == "custom":
423+
DTy = _device_type._CUSTOM
424+
elif dty_str == "gpu":
425+
DTy = _device_type._GPU
426+
else:
427+
raise ValueError(
428+
"Unexpected value of `device_type`."
429+
)
430+
elif isinstance(device_type, device_type_t):
431+
if device_type == device_type_t.all:
432+
DTy = _device_type._ALL_DEVICES
433+
elif device_type == device_type_t.accelerator:
434+
DTy = _device_type._ACCELERATOR
435+
elif device_type == device_type_t.automatic:
436+
DTy = _device_type._AUTOMATIC
437+
elif device_type == device_type_t.cpu:
438+
DTy = _device_type._CPU
439+
elif device_type == device_type_t.custom:
440+
DTy = _device_type._CUSTOM
441+
elif device_type == device_type_t.gpu:
442+
DTy = _device_type._GPU
443+
else:
444+
raise ValueError(
445+
"Unexpected value of `device_type`."
446+
)
447+
else:
448+
raise TypeError(
449+
"device type should be specified as a str or an "
450+
"``enum_types.device_type``."
451+
)
452+
DVRef = DPCTLPlatform_GetDevices(self.get_platform_ref(), DTy)
453+
if (DVRef is NULL):
454+
raise ValueError("Internal error: NULL device vector encountered")
455+
num_devs = DPCTLDeviceVector_Size(DVRef)
456+
devices = []
457+
for i in range(num_devs):
458+
DRef = DPCTLDeviceVector_GetAt(DVRef, i)
459+
devices.append(SyclDevice._create(DRef))
460+
DPCTLDeviceVector_Delete(DVRef)
461+
462+
return devices
463+
369464

370465
def lsplatform(verbosity=0):
371466
"""

libsyclinterface/include/syclinterface/dpctl_sycl_device_manager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ DPCTL_API
145145
size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier);
146146

147147
/*!
148-
* @brief Prints out the info::deivice attributes for the device that are
148+
* @brief Prints out the info::device attributes for the device that are
149149
* currently supported by dpctl.
150150
*
151151
* @param DRef A #DPCTLSyclDeviceRef opaque pointer.

libsyclinterface/include/syclinterface/dpctl_sycl_platform_interface.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "Support/ExternC.h"
3030
#include "Support/MemOwnershipAttrs.h"
3131
#include "dpctl_data_types.h"
32+
#include "dpctl_sycl_device_manager.h"
3233
#include "dpctl_sycl_enum_types.h"
3334
#include "dpctl_sycl_platform_manager.h"
3435
#include "dpctl_sycl_types.h"
@@ -176,6 +177,20 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);
176177
* @ingroup PlatformInterface
177178
*/
178179
DPCTL_API
179-
size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef PRef);
180+
size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef);
181+
182+
/*!
183+
* @brief Returns a vector of devices associated with sycl::platform referenced
184+
* by DPCTLSyclPlatformRef object.
185+
*
186+
* @param PRef The DPCTLSyclPlatformRef pointer.
187+
* @param DTy A DPCTLSyclDeviceType enum value.
188+
* @return A DPCTLDeviceVectorRef with devices associated with given PRef.
189+
* @ingroup PlatformInterface
190+
*/
191+
DPCTL_API
192+
__dpctl_give DPCTLDeviceVectorRef
193+
DPCTLPlatform_GetDevices(__dpctl_keep const DPCTLSyclPlatformRef PRef,
194+
DPCTLSyclDeviceType DTy);
180195

181196
DPCTL_C_EXTERN_C_END

libsyclinterface/source/dpctl_sycl_platform_interface.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "dpctl_device_selection.hpp"
3030
#include "dpctl_error_handlers.h"
3131
#include "dpctl_string_utils.hpp"
32+
#include "dpctl_sycl_enum_types.h"
3233
#include "dpctl_sycl_type_casters.hpp"
3334
#include "dpctl_utils_helper.h"
3435
#include <iomanip>
@@ -269,3 +270,39 @@ size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef)
269270
return 0;
270271
}
271272
}
273+
274+
__dpctl_give DPCTLDeviceVectorRef
275+
DPCTLPlatform_GetDevices(__dpctl_keep const DPCTLSyclPlatformRef PRef,
276+
DPCTLSyclDeviceType DTy)
277+
{
278+
auto P = unwrap<platform>(PRef);
279+
if (!P) {
280+
error_handler("Cannot retrieve devices from DPCTLSyclPlatformRef as "
281+
"input is a nullptr.",
282+
__FILE__, __func__, __LINE__);
283+
return nullptr;
284+
}
285+
using vecTy = std::vector<DPCTLSyclDeviceRef>;
286+
vecTy *DevicesVectorPtr = nullptr;
287+
try {
288+
DevicesVectorPtr = new vecTy();
289+
} catch (std::exception const &e) {
290+
delete DevicesVectorPtr;
291+
error_handler(e, __FILE__, __func__, __LINE__);
292+
return nullptr;
293+
}
294+
try {
295+
auto SyclDTy = DPCTL_DPCTLDeviceTypeToSyclDeviceType(DTy);
296+
auto Devices = P->get_devices(SyclDTy);
297+
DevicesVectorPtr->reserve(Devices.size());
298+
for (const auto &Dev : Devices) {
299+
DevicesVectorPtr->emplace_back(
300+
wrap<device>(new device(std::move(Dev))));
301+
}
302+
return wrap<vecTy>(DevicesVectorPtr);
303+
} catch (std::exception const &e) {
304+
delete DevicesVectorPtr;
305+
error_handler(e, __FILE__, __func__, __LINE__);
306+
return nullptr;
307+
}
308+
}

0 commit comments

Comments
 (0)