@@ -16,11 +16,41 @@ namespace py = pybind11;
16
16
UsmNDArray_GetQueueRef
17
17
*/
18
18
19
- sycl::event gemv (sycl::queue q,
20
- py::object matrix,
21
- py::object vector,
22
- py::object result,
23
- const std::vector<sycl::event> &depends = {})
19
+ sycl::event keep_args_alive (sycl::queue q,
20
+ py::object o1,
21
+ py::object o2,
22
+ py::object o3,
23
+ const std::vector<sycl::event> &depends = {})
24
+ {
25
+ sycl::event ht_event = q.submit ([&](sycl::handler &cgh) {
26
+ cgh.depends_on (depends);
27
+ std::shared_ptr<py::handle> shp1 = std::make_shared<py::handle>(o1);
28
+ std::shared_ptr<py::handle> shp2 = std::make_shared<py::handle>(o2);
29
+ std::shared_ptr<py::handle> shp3 = std::make_shared<py::handle>(o3);
30
+ shp1->inc_ref ();
31
+ shp2->inc_ref ();
32
+ shp3->inc_ref ();
33
+ cgh.host_task ([=]() {
34
+ bool guard = (Py_IsInitialized () && !_Py_IsFinalizing ());
35
+ if (guard) {
36
+ PyGILState_STATE gstate;
37
+ gstate = PyGILState_Ensure ();
38
+ shp1->dec_ref ();
39
+ shp2->dec_ref ();
40
+ shp3->dec_ref ();
41
+ PyGILState_Release (gstate);
42
+ }
43
+ });
44
+ });
45
+ return ht_event;
46
+ }
47
+
48
+ std::pair<sycl::event, sycl::event>
49
+ gemv (sycl::queue q,
50
+ py::object matrix,
51
+ py::object vector,
52
+ py::object result,
53
+ const std::vector<sycl::event> &depends = {})
24
54
{
25
55
PyObject *m_src = matrix.ptr ();
26
56
if (!PyObject_TypeCheck (m_src, &PyUSMArrayType)) {
@@ -131,7 +161,9 @@ sycl::event gemv(sycl::queue q,
131
161
throw std::runtime_error (" Type dispatch ran into trouble." );
132
162
}
133
163
134
- return res_ev;
164
+ sycl::event ht_event = keep_args_alive (q, matrix, vector, result, {res_ev});
165
+
166
+ return std::make_pair (ht_event, res_ev);
135
167
}
136
168
137
169
PYBIND11_MODULE (_onemkl, m)
0 commit comments