Skip to content

Commit cdf7b2f

Browse files
Merge pull request #1111 from IntelPython/host-task-cleanup
Host task cleanup
2 parents a91eaf3 + fd9286c commit cdf7b2f

File tree

4 files changed

+37
-26
lines changed

4 files changed

+37
-26
lines changed

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@ using dpctl::utils::keep_args_alive;
7070

7171
sycl::event _populate_packed_shape_strides_for_copycast_kernel(
7272
sycl::queue exec_q,
73+
std::vector<sycl::event> &host_task_events,
7374
py::ssize_t *device_shape_strides, // to be populated
7475
const std::vector<py::ssize_t> &common_shape,
7576
const std::vector<py::ssize_t> &src_strides,
@@ -102,13 +103,14 @@ sycl::event _populate_packed_shape_strides_for_copycast_kernel(
102103
shp_host_shape_strides->data(), device_shape_strides,
103104
shp_host_shape_strides->size());
104105

105-
exec_q.submit([&](sycl::handler &cgh) {
106+
auto shared_ptr_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
106107
cgh.depends_on(copy_shape_ev);
107108
cgh.host_task([shp_host_shape_strides]() {
108109
// increment shared pointer ref-count to keep it alive
109110
// till copy operation completes;
110111
});
111112
});
113+
host_task_events.push_back(shared_ptr_cleanup_ev);
112114

113115
return copy_shape_ev;
114116
}
@@ -306,26 +308,30 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
306308
throw std::runtime_error("Unabled to allocate device memory");
307309
}
308310

311+
std::vector<sycl::event> host_task_events;
312+
host_task_events.reserve(2);
313+
309314
sycl::event copy_shape_ev =
310315
_populate_packed_shape_strides_for_copycast_kernel(
311-
exec_q, shape_strides, simplified_shape, simplified_src_strides,
312-
simplified_dst_strides);
316+
exec_q, host_task_events, shape_strides, simplified_shape,
317+
simplified_src_strides, simplified_dst_strides);
313318

314319
sycl::event copy_and_cast_generic_ev = copy_and_cast_fn(
315320
exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
316321
dst_offset, depends, {copy_shape_ev});
317322

318323
// async free of shape_strides temporary
319324
auto ctx = exec_q.get_context();
320-
exec_q.submit([&](sycl::handler &cgh) {
325+
auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
321326
cgh.depends_on(copy_and_cast_generic_ev);
322327
cgh.host_task(
323328
[ctx, shape_strides]() { sycl::free(shape_strides, ctx); });
324329
});
325330

326-
return std::make_pair(
327-
keep_args_alive(exec_q, {src, dst}, {copy_and_cast_generic_ev}),
328-
copy_and_cast_generic_ev);
331+
host_task_events.push_back(temporaries_cleanup_ev);
332+
333+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events),
334+
temporaries_cleanup_ev);
329335
}
330336

331337
void init_copy_and_cast_usm_to_usm_dispatch_tables(void)

dpctl/tensor/libtensor/source/copy_for_reshape.cpp

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -202,11 +202,14 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
202202
dst_nd);
203203
}
204204

205+
std::vector<sycl::event> host_task_events;
206+
host_task_events.reserve(2);
207+
205208
// copy packed shapes and strides from host to devices
206209
sycl::event packed_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
207210
packed_host_shapes_strides_shp->data(), packed_shapes_strides,
208211
packed_host_shapes_strides_shp->size());
209-
exec_q.submit([&](sycl::handler &cgh) {
212+
auto shared_ptr_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
210213
cgh.depends_on(packed_shape_strides_copy_ev);
211214
cgh.host_task([packed_host_shapes_strides_shp] {
212215
// Capturing shared pointer ensures that the underlying vector is
@@ -215,6 +218,8 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
215218
});
216219
});
217220

221+
host_task_events.push_back(shared_ptr_cleanup_ev);
222+
218223
char *src_data = src.get_data();
219224
char *dst_data = dst.get_data();
220225

@@ -226,17 +231,18 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
226231
fn(exec_q, shift, src_nelems, src_nd, dst_nd, packed_shapes_strides,
227232
src_data, dst_data, all_deps);
228233

229-
exec_q.submit([&](sycl::handler &cgh) {
234+
auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
230235
cgh.depends_on(copy_for_reshape_event);
231236
auto ctx = exec_q.get_context();
232237
cgh.host_task([packed_shapes_strides, ctx]() {
233238
sycl::free(packed_shapes_strides, ctx);
234239
});
235240
});
236241

237-
return std::make_pair(
238-
keep_args_alive(exec_q, {src, dst}, {copy_for_reshape_event}),
239-
copy_for_reshape_event);
242+
host_task_events.push_back(temporaries_cleanup_ev);
243+
244+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, host_task_events),
245+
temporaries_cleanup_ev);
240246
}
241247

242248
void init_copy_for_reshape_dispatch_vectors(void)

dpctl/tensor/libtensor/source/integer_advanced_indexing.cpp

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -180,17 +180,15 @@ _populate_kernel_params(sycl::queue exec_q,
180180
host_along_sh_st_shp->data(), device_along_sh_st,
181181
host_along_sh_st_shp->size());
182182

183-
sycl::event shared_ptr_cleanup_host_task =
184-
exec_q.submit([&](sycl::handler &cgh) {
185-
cgh.depends_on({device_along_sh_st_copy_ev,
186-
device_orthog_sh_st_copy_ev,
187-
device_ind_offsets_copy_ev,
188-
device_ind_sh_st_copy_ev, device_ind_ptrs_copy_ev});
189-
cgh.host_task([host_ind_offsets_shp, host_ind_sh_st_shp,
190-
host_ind_ptrs_shp, host_orthog_sh_st_shp,
191-
host_along_sh_st_shp]() {});
192-
});
193-
host_task_events.push_back(shared_ptr_cleanup_host_task);
183+
sycl::event shared_ptr_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
184+
cgh.depends_on({device_along_sh_st_copy_ev, device_orthog_sh_st_copy_ev,
185+
device_ind_offsets_copy_ev, device_ind_sh_st_copy_ev,
186+
device_ind_ptrs_copy_ev});
187+
cgh.host_task([host_ind_offsets_shp, host_ind_sh_st_shp,
188+
host_ind_ptrs_shp, host_orthog_sh_st_shp,
189+
host_along_sh_st_shp]() {});
190+
});
191+
host_task_events.push_back(shared_ptr_cleanup_ev);
194192

195193
std::vector<sycl::event> sh_st_pack_deps{
196194
device_ind_ptrs_copy_ev, device_ind_sh_st_copy_ev,

dpctl/tensor/libtensor/source/triul_ctor.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -250,7 +250,7 @@ usm_ndarray_triul(sycl::queue exec_q,
250250
dev_shape_and_strides, k, depends, {copy_shape_and_strides});
251251
}
252252

253-
exec_q.submit([&](sycl::handler &cgh) {
253+
auto temporaries_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) {
254254
cgh.depends_on({tri_ev});
255255
auto ctx = exec_q.get_context();
256256
cgh.host_task(
@@ -261,8 +261,9 @@ usm_ndarray_triul(sycl::queue exec_q,
261261
});
262262
});
263263

264-
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
265-
tri_ev);
264+
return std::make_pair(
265+
keep_args_alive(exec_q, {src, dst}, {temporaries_cleanup_ev}),
266+
temporaries_cleanup_ev);
266267
}
267268

268269
void init_triul_ctor_dispatch_vectors(void)

0 commit comments

Comments
 (0)