@@ -238,38 +238,30 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
 
     auto fn = strided_dispatch_vector[src_typeid];
 
+    // using a single host_task for packing here
+    // prevents crashes on CPU
     std::vector<sycl::event> host_task_events{};
-    const auto &iter_src_dst_metadata_packing_triple_ =
+    const auto &iter_red_metadata_packing_triple_ =
         dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
             exec_q, host_task_events, simplified_iter_shape,
-            simplified_iter_src_strides, simplified_iter_dst_strides);
-    py::ssize_t *iter_shape_and_strides =
-        std::get<0>(iter_src_dst_metadata_packing_triple_);
-    if (iter_shape_and_strides == nullptr) {
+            simplified_iter_src_strides, simplified_iter_dst_strides,
+            simplified_red_shape, simplified_red_src_strides);
+    py::ssize_t *packed_shapes_and_strides =
+        std::get<0>(iter_red_metadata_packing_triple_);
+    if (packed_shapes_and_strides == nullptr) {
         throw std::runtime_error("Unable to allocate memory on device");
     }
-    const auto &copy_iter_metadata_ev =
-        std::get<2>(iter_src_dst_metadata_packing_triple_);
+    const auto &copy_metadata_ev =
+        std::get<2>(iter_red_metadata_packing_triple_);
 
-    const auto &red_metadata_packing_triple_ =
-        dpctl::tensor::offset_utils::device_allocate_and_pack<py::ssize_t>(
-            exec_q, host_task_events, simplified_red_shape,
-            simplified_red_src_strides);
-    py::ssize_t *red_shape_stride = std::get<0>(red_metadata_packing_triple_);
-    if (red_shape_stride == nullptr) {
-        sycl::event::wait(host_task_events);
-        sycl::free(iter_shape_and_strides, exec_q);
-        throw std::runtime_error("Unable to allocate memory on device");
-    }
-    const auto &copy_red_metadata_ev =
-        std::get<2>(red_metadata_packing_triple_);
+    py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides;
+    py::ssize_t *red_shape_stride = packed_shapes_and_strides + (3 * iter_nd);
 
     std::vector<sycl::event> all_deps;
-    all_deps.reserve(depends.size() + 2);
+    all_deps.reserve(depends.size() + 1);
     all_deps.resize(depends.size());
     std::copy(depends.begin(), depends.end(), all_deps.begin());
-    all_deps.push_back(copy_iter_metadata_ev);
-    all_deps.push_back(copy_red_metadata_ev);
+    all_deps.push_back(copy_metadata_ev);
 
     auto red_ev =
         fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, dst_nd,
@@ -279,9 +271,8 @@ py_boolean_reduction(dpctl::tensor::usm_ndarray src,
     sycl::event temp_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
         cgh.depends_on(red_ev);
         auto ctx = exec_q.get_context();
-        cgh.host_task([ctx, iter_shape_and_strides, red_shape_stride] {
-            sycl::free(iter_shape_and_strides, ctx);
-            sycl::free(red_shape_stride, ctx);
+        cgh.host_task([ctx, packed_shapes_and_strides] {
+            sycl::free(packed_shapes_and_strides, ctx);
         });
     });
     host_task_events.push_back(temp_cleanup_ev);
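A note on the new pointer arithmetic: the diff implies that `device_allocate_and_pack` concatenates its vector arguments, in order, into one USM allocation, so the three iteration-space vectors (`iter_nd` entries each) occupy the first `3 * iter_nd` elements and the reduction metadata starts immediately after them. Below is a minimal host-side sketch of that assumed layout (plain C++ with hypothetical names, `std::ptrdiff_t` standing in for `py::ssize_t`; this is not the actual dpctl helper):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of the packed layout the diff relies on:
//   [iter_shape | iter_src_strides | iter_dst_strides | red_shape | red_src_strides]
//    <--------------- 3 * iter_nd --------------->     <------ 2 * red_nd ------>
// The real helper copies this buffer into USM device memory and returns the
// device pointer together with a copy event.
std::vector<std::ptrdiff_t>
pack_metadata(const std::vector<std::ptrdiff_t> &iter_shape,
              const std::vector<std::ptrdiff_t> &iter_src_strides,
              const std::vector<std::ptrdiff_t> &iter_dst_strides,
              const std::vector<std::ptrdiff_t> &red_shape,
              const std::vector<std::ptrdiff_t> &red_src_strides)
{
    std::vector<std::ptrdiff_t> packed;
    packed.reserve(3 * iter_shape.size() + 2 * red_shape.size());
    for (const auto *v : {&iter_shape, &iter_src_strides, &iter_dst_strides,
                          &red_shape, &red_src_strides}) {
        packed.insert(packed.end(), v->begin(), v->end());
    }
    // With iter_nd == iter_shape.size(), the reduction metadata begins at
    // packed.data() + 3 * iter_nd, matching red_shape_stride in the diff.
    return packed;
}
```

Packing everything into one allocation also means there is a single copy event to depend on and a single `sycl::free` in the cleanup host task, which is why the diff can drop the second packing call along with its partial-failure cleanup path (`sycl::event::wait` plus the intermediate `sycl::free`).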