Skip to content

Commit afc8821

Browse files
gemv function returns a pair of events, first for host task, second for MKL computation
1 parent cda9417 commit afc8821

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,41 @@ namespace py = pybind11;
1616
UsmNDArray_GetQueueRef
1717
*/
1818

19-
sycl::event gemv(sycl::queue q,
20-
py::object matrix,
21-
py::object vector,
22-
py::object result,
23-
const std::vector<sycl::event> &depends = {})
19+
sycl::event keep_args_alive(sycl::queue q,
20+
py::object o1,
21+
py::object o2,
22+
py::object o3,
23+
const std::vector<sycl::event> &depends = {})
24+
{
25+
sycl::event ht_event = q.submit([&](sycl::handler &cgh) {
26+
cgh.depends_on(depends);
27+
std::shared_ptr<py::handle> shp1 = std::make_shared<py::handle>(o1);
28+
std::shared_ptr<py::handle> shp2 = std::make_shared<py::handle>(o2);
29+
std::shared_ptr<py::handle> shp3 = std::make_shared<py::handle>(o3);
30+
shp1->inc_ref();
31+
shp2->inc_ref();
32+
shp3->inc_ref();
33+
cgh.host_task([=]() {
34+
bool guard = (Py_IsInitialized() && !_Py_IsFinalizing());
35+
if (guard) {
36+
PyGILState_STATE gstate;
37+
gstate = PyGILState_Ensure();
38+
shp1->dec_ref();
39+
shp2->dec_ref();
40+
shp3->dec_ref();
41+
PyGILState_Release(gstate);
42+
}
43+
});
44+
});
45+
return ht_event;
46+
}
47+
48+
std::pair<sycl::event, sycl::event>
49+
gemv(sycl::queue q,
50+
py::object matrix,
51+
py::object vector,
52+
py::object result,
53+
const std::vector<sycl::event> &depends = {})
2454
{
2555
PyObject *m_src = matrix.ptr();
2656
if (!PyObject_TypeCheck(m_src, &PyUSMArrayType)) {
@@ -131,7 +161,9 @@ sycl::event gemv(sycl::queue q,
131161
throw std::runtime_error("Type dispatch ran into trouble.");
132162
}
133163

134-
return res_ev;
164+
sycl::event ht_event = keep_args_alive(q, matrix, vector, result, {res_ev});
165+
166+
return std::make_pair(ht_event, res_ev);
135167
}
136168

137169
PYBIND11_MODULE(_onemkl, m)

examples/pybind11/onemkl_gemv/tests/test_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_gemv():
1515
r = dpt.empty((5,), dtype="d", sycl_queue=q)
1616
M = dpt.asarray(Mnp, sycl_queue=q)
1717
v = dpt.asarray(vnp, sycl_queue=q)
18-
ev = gemv(M.sycl_queue, M, v, r, [])
19-
ev.wait()
18+
hev, ev = gemv(M.sycl_queue, M, v, r, [])
19+
hev.wait()
2020
rnp = dpt.asnumpy(r)
2121
assert np.allclose(rnp, Mnp @ vnp)

0 commit comments

Comments
 (0)