Skip to content

Commit 60bdc6d

Browse files
Exported keep_args_alive to dpctl4pybind11 and used that in tensor_py and onemkl_gemv example
1 parent 33065a2 commit 60bdc6d

File tree

3 files changed

+40
-56
lines changed

3 files changed

+40
-56
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
#include "dpctl_capi.h"
2929
#include <CL/sycl.hpp>
3030
#include <complex>
31+
#include <memory>
3132
#include <pybind11/pybind11.h>
33+
#include <vector>
3234

3335
namespace py = pybind11;
3436

@@ -497,4 +499,38 @@ class usm_ndarray : public py::object
497499
};
498500

499501
} // end namespace tensor
502+
503+
namespace utils
504+
{
505+
506+
template <std::size_t num>
507+
sycl::event keep_args_alive(sycl::queue q,
508+
const py::object (&py_objs)[num],
509+
const std::vector<sycl::event> &depends = {})
510+
{
511+
sycl::event host_task_ev = q.submit([&](sycl::handler &cgh) {
512+
cgh.depends_on(depends);
513+
std::array<std::shared_ptr<py::handle>, num> shp_arr;
514+
for (std::size_t i = 0; i < num; ++i) {
515+
shp_arr[i] = std::make_shared<py::handle>(py_objs[i]);
516+
shp_arr[i]->inc_ref();
517+
}
518+
cgh.host_task([=]() {
519+
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());
520+
if (guard) {
521+
PyGILState_STATE gstate;
522+
gstate = PyGILState_Ensure();
523+
for (std::size_t i = 0; i < num; ++i) {
524+
shp_arr[i]->dec_ref();
525+
}
526+
PyGILState_Release(gstate);
527+
}
528+
});
529+
});
530+
531+
return host_task_ev;
532+
}
533+
534+
} // end namespace utils
535+
500536
} // end namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -342,33 +342,7 @@ std::vector<py::ssize_t> f_contiguous_strides(int nd,
342342
}
343343
}
344344

345-
template <std::size_t num>
346-
sycl::event keep_args_alive(sycl::queue q,
347-
const py::object (&py_objs)[num],
348-
const std::vector<sycl::event> &depends = {})
349-
{
350-
sycl::event host_task_ev = q.submit([&](sycl::handler &cgh) {
351-
cgh.depends_on(depends);
352-
std::array<std::shared_ptr<py::handle>, num> shp_arr;
353-
for (std::size_t i = 0; i < num; ++i) {
354-
shp_arr[i] = std::make_shared<py::handle>(py_objs[i]);
355-
shp_arr[i]->inc_ref();
356-
}
357-
cgh.host_task([=]() {
358-
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());
359-
if (guard) {
360-
PyGILState_STATE gstate;
361-
gstate = PyGILState_Ensure();
362-
for (std::size_t i = 0; i < num; ++i) {
363-
shp_arr[i]->dec_ref();
364-
}
365-
PyGILState_Release(gstate);
366-
}
367-
});
368-
});
369-
370-
return host_task_ev;
371-
}
345+
using dpctl::utils::keep_args_alive;
372346

373347
void simplify_iteration_space(int &nd,
374348
const py::ssize_t *&shape,

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,7 @@
88

99
namespace py = pybind11;
1010

11-
sycl::event keep_args_alive(sycl::queue q,
12-
py::object o1,
13-
py::object o2,
14-
py::object o3,
15-
const std::vector<sycl::event> &depends = {})
16-
{
17-
sycl::event ht_event = q.submit([&](sycl::handler &cgh) {
18-
cgh.depends_on(depends);
19-
std::shared_ptr<py::handle> shp1 = std::make_shared<py::handle>(o1);
20-
std::shared_ptr<py::handle> shp2 = std::make_shared<py::handle>(o2);
21-
std::shared_ptr<py::handle> shp3 = std::make_shared<py::handle>(o3);
22-
shp1->inc_ref();
23-
shp2->inc_ref();
24-
shp3->inc_ref();
25-
cgh.host_task([=]() {
26-
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());
27-
if (guard) {
28-
PyGILState_STATE gstate;
29-
gstate = PyGILState_Ensure();
30-
shp1->dec_ref();
31-
shp2->dec_ref();
32-
shp3->dec_ref();
33-
PyGILState_Release(gstate);
34-
}
35-
});
36-
});
37-
return ht_event;
38-
}
11+
using dpctl::utils::keep_args_alive;
3912

4013
std::pair<sycl::event, sycl::event>
4114
gemv(sycl::queue q,
@@ -131,7 +104,8 @@ gemv(sycl::queue q,
131104
throw std::runtime_error("Type dispatch ran into trouble.");
132105
}
133106

134-
sycl::event ht_event = keep_args_alive(q, matrix, vector, result, {res_ev});
107+
sycl::event ht_event =
108+
keep_args_alive(q, {matrix, vector, result}, {res_ev});
135109

136110
return std::make_pair(ht_event, res_ev);
137111
}

0 commit comments

Comments
 (0)