3
3
#include < pybind11/numpy.h>
4
4
#include < pybind11/pybind11.h>
5
5
6
+ // clang-format off
7
+ #include " dpctl_sycl_types.h"
6
8
#include " ../_sycl_queue.h"
7
9
#include " ../_sycl_queue_api.h"
8
- #include " dpctl_sycl_types.h"
10
+ #include " ../_sycl_device.h"
11
+ #include " ../_sycl_device_api.h"
12
+ // clang-format on
9
13
10
14
namespace py = pybind11;
11
15
@@ -25,6 +29,34 @@ size_t get_max_compute_units(py::object queue)
25
29
}
26
30
}
27
31
32
+ uint64_t get_device_global_mem_size (py::object device)
33
+ {
34
+ PyObject *device_pycapi = device.ptr ();
35
+ if (PyObject_TypeCheck (device_pycapi, &PySyclDeviceType)) {
36
+ DPCTLSyclDeviceRef DRef = get_device_ref (
37
+ reinterpret_cast <PySyclDeviceObject *>(device_pycapi));
38
+ sycl::device *d_ptr = reinterpret_cast <sycl::device *>(DRef);
39
+ return d_ptr->get_info <sycl::info::device::global_mem_size>();
40
+ }
41
+ else {
42
+ throw std::runtime_error (" expected dpctl.SyclDevice as argument" );
43
+ }
44
+ }
45
+
46
+ uint64_t get_device_local_mem_size (py::object device)
47
+ {
48
+ PyObject *device_pycapi = device.ptr ();
49
+ if (PyObject_TypeCheck (device_pycapi, &PySyclDeviceType)) {
50
+ DPCTLSyclDeviceRef DRef = get_device_ref (
51
+ reinterpret_cast <PySyclDeviceObject *>(device_pycapi));
52
+ sycl::device *d_ptr = reinterpret_cast <sycl::device *>(DRef);
53
+ return d_ptr->get_info <sycl::info::device::local_mem_size>();
54
+ }
55
+ else {
56
+ throw std::runtime_error (" expected dpctl.SyclDevice as argument" );
57
+ }
58
+ }
59
+
28
60
py::array_t <int64_t >
29
61
offloaded_array_mod (py::object queue,
30
62
py::array_t <int64_t , py::array::c_style> array,
@@ -82,11 +114,16 @@ offloaded_array_mod(py::object queue,
82
114
83
115
PYBIND11_MODULE (pybind11_example, m)
84
116
{
85
- // Import the dpctl._sycl_queue extension
117
+ // Import the dpctl._sycl_queue, dpctl._sycl_device extensions
118
+ import_dpctl___sycl_device ();
86
119
import_dpctl___sycl_queue ();
87
120
m.def (" get_max_compute_units" , &get_max_compute_units,
88
121
" Computes max_compute_units property of the device underlying given "
89
122
" dpctl.SyclQueue" );
123
+ m.def (" get_device_global_mem_size" , &get_device_global_mem_size,
124
+ " Computes amount of global memory of the given dpctl.SyclDevice" );
125
+ m.def (" get_device_local_mem_size" , &get_device_local_mem_size,
126
+ " Computes amount of local memory of the given dpctl.SyclDevice" );
90
127
m.def (" offloaded_array_mod" , &offloaded_array_mod,
91
128
" Compute offloaded modular reduction of integer-valued NumPy array" );
92
129
}
0 commit comments