@@ -64,46 +64,16 @@ static sycl::event div_impl(sycl::queue exec_q,
6464{
6565 type_utils::validate_type_for_device<T>(exec_q);
6666
67- std::cerr << " enter div_impl" << std::endl;
67+ const T* a = reinterpret_cast <const T*>(in_a);
68+ const T* b = reinterpret_cast <const T*>(in_b);
69+ T* y = reinterpret_cast <T*>(out_y);
6870
69- const T* _a = reinterpret_cast <const T*>(in_a);
70- const T* _b = reinterpret_cast <const T*>(in_b);
71- T* _y = reinterpret_cast <T*>(out_y);
72-
73- std::cerr << " casting is done" << std::endl;
74-
75- T* a = sycl::malloc_device<T>(n, exec_q);
76- T* b = sycl::malloc_device<T>(n, exec_q);
77- T* y = sycl::malloc_device<T>(n, exec_q);
78-
79- std::cerr << " malloc is done" << std::endl;
80-
81- exec_q.copy (_a, a, n).wait ();
82- exec_q.copy (_b, b, n).wait ();
83- exec_q.copy (_y, y, n).wait ();
84-
85- std::cerr << " copy is done" << std::endl;
86-
87- sycl::event ev = mkl_vm::div (exec_q,
71+ return mkl_vm::div (exec_q,
8872 n, // number of elements to be calculated
8973 a, // pointer `a` containing 1st input vector of size n
9074 b, // pointer `b` containing 2nd input vector of size n
9175 y, // pointer `y` to the output vector of size n
9276 depends);
93- ev.wait ();
94-
95- std::cerr << " div is done" << std::endl;
96-
97- exec_q.copy (y, _y, n).wait ();
98-
99- std::cerr << " copy is done" << std::endl;
100-
101- sycl::free (a, exec_q);
102- sycl::free (b, exec_q);
103- sycl::free (y, exec_q);
104-
105- std::cerr << " leaving div_impl" << std::endl;
106- return sycl::event ();
10777}
10878
10979std::pair<sycl::event, sycl::event> div (sycl::queue exec_q,
@@ -205,20 +175,9 @@ std::pair<sycl::event, sycl::event> div(sycl::queue exec_q,
205175 throw py::value_error (" No div implementation defined" );
206176 }
207177 sycl::event sum_ev = div_fn (exec_q, src_nelems, src1_data, src2_data, dst_data, depends);
208- // sum_ev.wait();
209-
210- // int* dummy = sycl::malloc_device<int>(1, exec_q);
211- // sycl::event cleanup_ev = exec_q.submit([&](sycl::handler& cgh) {
212- // // cgh.depends_on(sum_ev);
213- // auto ctx = exec_q.get_context();
214- // cgh.host_task([dummy, ctx]() {
215- // // dummy host task to pass into keep_args_alive
216- // sycl::free(dummy, ctx);
217- // });
218- // });
219-
220- // sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, {sum_ev});
221- // return std::make_pair(ht_ev, sum_ev);
178+
179+ sycl::event ht_ev = dpctl::utils::keep_args_alive (exec_q, {src1, src2, dst}, {sum_ev});
180+ return std::make_pair (ht_ev, sum_ev);
222181 return std::make_pair (sycl::event (), sycl::event ());
223182}
224183
@@ -227,6 +186,7 @@ bool can_call_div(sycl::queue exec_q,
227186 dpctl::tensor::usm_ndarray src2,
228187 dpctl::tensor::usm_ndarray dst)
229188{
189+ #if INTEL_MKL_VERSION >= 20230002
230190 // check type_nums
231191 int src1_typenum = src1.get_typenum ();
232192 int src2_typenum = src2.get_typenum ();
@@ -325,6 +285,16 @@ bool can_call_div(sycl::queue exec_q,
325285 return false ;
326286 }
327287 return true ;
288+ #else
289+ // In OneMKL 2023.1.0 the call of oneapi::mkl::vm::div() is going to dead lock
290+ // inside ~usm_wrapper_to_host()->{...; q_->wait_and_throw(); ...}
291+
292+ (void )exec_q;
293+ (void )src1;
294+ (void )src2;
295+ (void )dst;
296+ return false ;
297+ #endif // INTEL_MKL_VERSION >= 20230002
328298}
329299
330300template <typename fnT, typename T>
0 commit comments