Skip to content

Commit 8d704db

Browse files
Merge pull request #588 from IntelPython/expand-pybind11-example
Used clang-format off, clang-format on to avoid include reordering
2 parents 225453f + 926dfaa commit 8d704db

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

examples/pybind11/use_dpctl_syclqueue/example.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@
2525

2626
# Pass dpctl.SyclQueue to Pybind11 extension
2727
eu_count = eg.get_max_compute_units(q)
28+
global_mem_size = eg.get_device_global_mem_size(q.sycl_device)
29+
local_mem_size = eg.get_device_local_mem_size(q.sycl_device)
2830

2931
print(f"EU count returned by Pybind11 extension {eu_count}")
3032
print("EU count computed by dpctl {}".format(q.sycl_device.max_compute_units))
33+
print("Device's global memory size: {} bytes".format(global_mem_size))
34+
print("Device's local memory size: {} bytes".format(local_mem_size))
3135

3236
print("")
3337
print("Computing modular reduction using SYCL on a NumPy array")

examples/pybind11/use_dpctl_syclqueue/pybind11_example.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
#include <pybind11/numpy.h>
44
#include <pybind11/pybind11.h>
55

6+
// clang-format off
7+
#include "dpctl_sycl_types.h"
68
#include "../_sycl_queue.h"
79
#include "../_sycl_queue_api.h"
8-
#include "dpctl_sycl_types.h"
10+
#include "../_sycl_device.h"
11+
#include "../_sycl_device_api.h"
12+
// clang-format on
913

1014
namespace py = pybind11;
1115

@@ -25,6 +29,34 @@ size_t get_max_compute_units(py::object queue)
2529
}
2630
}
2731

32+
uint64_t get_device_global_mem_size(py::object device)
33+
{
34+
PyObject *device_pycapi = device.ptr();
35+
if (PyObject_TypeCheck(device_pycapi, &PySyclDeviceType)) {
36+
DPCTLSyclDeviceRef DRef = get_device_ref(
37+
reinterpret_cast<PySyclDeviceObject *>(device_pycapi));
38+
sycl::device *d_ptr = reinterpret_cast<sycl::device *>(DRef);
39+
return d_ptr->get_info<sycl::info::device::global_mem_size>();
40+
}
41+
else {
42+
throw std::runtime_error("expected dpctl.SyclDevice as argument");
43+
}
44+
}
45+
46+
uint64_t get_device_local_mem_size(py::object device)
47+
{
48+
PyObject *device_pycapi = device.ptr();
49+
if (PyObject_TypeCheck(device_pycapi, &PySyclDeviceType)) {
50+
DPCTLSyclDeviceRef DRef = get_device_ref(
51+
reinterpret_cast<PySyclDeviceObject *>(device_pycapi));
52+
sycl::device *d_ptr = reinterpret_cast<sycl::device *>(DRef);
53+
return d_ptr->get_info<sycl::info::device::local_mem_size>();
54+
}
55+
else {
56+
throw std::runtime_error("expected dpctl.SyclDevice as argument");
57+
}
58+
}
59+
2860
py::array_t<int64_t>
2961
offloaded_array_mod(py::object queue,
3062
py::array_t<int64_t, py::array::c_style> array,
@@ -82,11 +114,16 @@ offloaded_array_mod(py::object queue,
82114

83115
PYBIND11_MODULE(pybind11_example, m)
84116
{
85-
// Import the dpctl._sycl_queue extension
117+
// Import the dpctl._sycl_queue, dpctl._sycl_device extensions
118+
import_dpctl___sycl_device();
86119
import_dpctl___sycl_queue();
87120
m.def("get_max_compute_units", &get_max_compute_units,
88121
"Computes max_compute_units property of the device underlying given "
89122
"dpctl.SyclQueue");
123+
m.def("get_device_global_mem_size", &get_device_global_mem_size,
124+
"Computes amount of global memory of the given dpctl.SyclDevice");
125+
m.def("get_device_local_mem_size", &get_device_local_mem_size,
126+
"Computes amount of local memory of the given dpctl.SyclDevice");
90127
m.def("offloaded_array_mod", &offloaded_array_mod,
91128
"Compute offloaded modular reduction of integer-valued NumPy array");
92129
}

0 commit comments

Comments
 (0)