@@ -64,16 +64,32 @@ static sycl::event div_impl(sycl::queue exec_q,
6464{
6565 type_utils::validate_type_for_device<T>(exec_q);
6666
67- const T* a = reinterpret_cast <const T*>(in_a);
68- const T* b = reinterpret_cast <const T*>(in_b);
69- T* y = reinterpret_cast <T*>(out_y);
67+ const T* _a = reinterpret_cast <const T*>(in_a);
68+ const T* _b = reinterpret_cast <const T*>(in_b);
69+ T* _y = reinterpret_cast <T*>(out_y);
7070
71- return mkl_vm::div (exec_q,
71+ T* a = sycl::malloc_device<T>(n, exec_q);
72+ T* b = sycl::malloc_device<T>(n, exec_q);
73+ T* y = sycl::malloc_device<T>(n, exec_q);
74+
75+ exec_q.copy (_a, a, n).wait ();
76+ exec_q.copy (_b, b, n).wait ();
77+ exec_q.copy (_y, y, n).wait ();
78+
79+ sycl::event ev = mkl_vm::div (exec_q,
7280 n, // number of elements to be calculated
7381 a, // pointer `a` containing 1st input vector of size n
7482 b, // pointer `b` containing 2nd input vector of size n
7583 y, // pointer `y` to the output vector of size n
7684 depends);
85+ ev.wait ();
86+
87+ exec_q.copy (y, _y, n).wait ();
88+
89+ sycl::free (a, exec_q);
90+ sycl::free (b, exec_q);
91+ sycl::free (y, exec_q);
92+ return sycl::event ();
7793}
7894
7995std::pair<sycl::event, sycl::event> div (sycl::queue exec_q,
@@ -175,9 +191,21 @@ std::pair<sycl::event, sycl::event> div(sycl::queue exec_q,
175191 throw py::value_error (" No div implementation defined" );
176192 }
177193 sycl::event sum_ev = div_fn (exec_q, src_nelems, src1_data, src2_data, dst_data, depends);
178-
179- sycl::event ht_ev = dpctl::utils::keep_args_alive (exec_q, {src1, src2, dst}, {sum_ev});
180- return std::make_pair (ht_ev, sum_ev);
194+ // sum_ev.wait();
195+
196+ // int* dummy = sycl::malloc_device<int>(1, exec_q);
197+ // sycl::event cleanup_ev = exec_q.submit([&](sycl::handler& cgh) {
198+ // // cgh.depends_on(sum_ev);
199+ // auto ctx = exec_q.get_context();
200+ // cgh.host_task([dummy, ctx]() {
201+ // // dummy host task to pass into keep_args_alive
202+ // sycl::free(dummy, ctx);
203+ // });
204+ // });
205+
206+ // sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, {sum_ev});
207+ // return std::make_pair(ht_ev, sum_ev);
208+ return std::make_pair (sycl::event (), sycl::event ());
181209}
182210
183211bool can_call_div (sycl::queue exec_q,
0 commit comments