Commit b2a66d2
Merge pull request #1207 from IntelPython/refactor-host-tasks
Refactored boolean advanced indexing host tasks
2 parents 1320d39 + 7a62edb commit b2a66d2

File tree

1 file changed: +55 -76 lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 55 additions & 76 deletions
@@ -385,6 +385,9 @@ py_extract(dpctl::tensor::usm_ndarray src,
         auto fn =
             masked_extract_all_slices_strided_impl_dispatch_vector[src_typeid];

+        assert(dst_shape_vec.size() == 1);
+        assert(dst_strides_vec.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
@@ -397,9 +400,6 @@ py_extract(dpctl::tensor::usm_ndarray src,
         sycl::event copy_src_shape_strides_ev =
             std::get<2>(ptr_size_event_tuple1);

-        assert(dst_shape_vec.size() == 1);
-        assert(dst_strides_vec.size() == 1);
-
         std::vector<sycl::event> all_deps;
         all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
@@ -469,40 +469,31 @@ py_extract(dpctl::tensor::usm_ndarray src,
             simplified_ortho_shape, simplified_ortho_src_strides,
             simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);

+        assert(masked_dst_shape.size() == 1);
+        assert(masked_dst_strides.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_src_strides, simplified_ortho_dst_strides);
-        py::ssize_t *packed_ortho_src_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_src_dst_shape_strides == nullptr) {
+                simplified_ortho_src_strides, simplified_ortho_dst_strides,
+                masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);

-        const auto &ptr_size_event_tuple2 =
-            device_allocate_and_pack<py::ssize_t>(
-                exec_q, host_task_events, masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_src_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_src_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_src_dst_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
-
-        assert(masked_dst_shape.size() == 1);
-        assert(masked_dst_strides.size() == 1);
+            packed_shapes_strides + (3 * ortho_nd);

         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
+        all_deps.push_back(copy_shapes_strides_ev);

-        assert(all_deps.size() == depends.size() + 2);
+        assert(all_deps.size() == depends.size() + 1);

         // OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
         // masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
             exec_q.submit([&](sycl::handler &cgh) {
                 cgh.depends_on(extract_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_src_dst_shape_strides,
-                               packed_masked_src_shape_strides] {
-                    sycl::free(packed_ortho_src_dst_shape_strides, ctx);
-                    sycl::free(packed_masked_src_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
@@ -684,13 +673,16 @@ py_place(dpctl::tensor::usm_ndarray dst,
     auto rhs_shape_vec = rhs.get_shape_vector();
     auto rhs_strides_vec = rhs.get_strides_vector();

-    sycl::event extract_ev;
+    sycl::event place_ev;
     std::vector<sycl::event> host_task_events{};
     if (axis_start == 0 && axis_end == dst_nd) {
         // empty orthogonal directions
         auto fn =
             masked_place_all_slices_strided_impl_dispatch_vector[dst_typeid];

+        assert(rhs_shape_vec.size() == 1);
+        assert(rhs_strides_vec.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
@@ -703,23 +695,20 @@ py_place(dpctl::tensor::usm_ndarray dst,
         sycl::event copy_dst_shape_strides_ev =
             std::get<2>(ptr_size_event_tuple1);

-        assert(rhs_shape_vec.size() == 1);
-        assert(rhs_strides_vec.size() == 1);
-
         std::vector<sycl::event> all_deps;
         all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
         all_deps.push_back(copy_dst_shape_strides_ev);

         assert(all_deps.size() == depends.size() + 1);

-        extract_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
-                        rhs_data_p, dst_nd, packed_dst_shape_strides,
-                        rhs_shape_vec[0], rhs_strides_vec[0], all_deps);
+        place_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p,
+                      dst_nd, packed_dst_shape_strides, rhs_shape_vec[0],
+                      rhs_strides_vec[0], all_deps);

         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
                 cgh.host_task([ctx, packed_dst_shape_strides] {
                     sycl::free(packed_dst_shape_strides, ctx);
@@ -774,69 +763,59 @@ py_place(dpctl::tensor::usm_ndarray dst,
             simplified_ortho_shape, simplified_ortho_dst_strides,
             simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);

+        assert(masked_rhs_shape.size() == 1);
+        assert(masked_rhs_strides.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
-        py::ssize_t *packed_ortho_dst_rhs_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_dst_rhs_shape_strides == nullptr) {
+                simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
+                masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);

-        auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t>(
-            exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_dst_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_dst_rhs_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
-
-        assert(masked_rhs_shape.size() == 1);
-        assert(masked_rhs_strides.size() == 1);
+            packed_shapes_strides + (3 * ortho_nd);

         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
-
-        assert(all_deps.size() == depends.size() + 2);
-
-        extract_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
-                        cumsum_data_p, rhs_data_p,
-                        // data to build orthog_dst_rhs_indexer
-                        ortho_nd, packed_ortho_dst_rhs_shape_strides,
-                        ortho_dst_offset, ortho_rhs_offset,
-                        // data to build masked_dst_indexer
-                        masked_dst_nd, packed_masked_dst_shape_strides,
-                        // data to build masked_dst_indexer,
-                        masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
+        all_deps.push_back(copy_shapes_strides_ev);
+
+        assert(all_deps.size() == depends.size() + 1);
+
+        place_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
+                      cumsum_data_p, rhs_data_p,
+                      // data to build orthog_dst_rhs_indexer
+                      ortho_nd, packed_ortho_dst_rhs_shape_strides,
+                      ortho_dst_offset, ortho_rhs_offset,
+                      // data to build masked_dst_indexer
+                      masked_dst_nd, packed_masked_dst_shape_strides,
+                      // data to build masked_dst_indexer,
+                      masked_rhs_shape[0], masked_rhs_strides[0], all_deps);

         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_dst_rhs_shape_strides,
-                               packed_masked_dst_shape_strides] {
-                    sycl::free(packed_ortho_dst_rhs_shape_strides, ctx);
-                    sycl::free(packed_masked_dst_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
     }

-    host_task_events.push_back(extract_ev);
+    host_task_events.push_back(place_ev);

     sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
         exec_q, {dst, cumsum, rhs}, host_task_events);

-    return std::make_pair(py_obj_management_host_task_ev, extract_ev);
+    return std::make_pair(py_obj_management_host_task_ev, place_ev);
 }

 // Non-zero
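
The change in both py_extract and py_place is the same: the two device_allocate_and_pack calls (one for the orthogonal shapes/strides, one for the masked shapes/strides) are merged into a single packed allocation, the masked pointer is recovered as the offset packed_shapes_strides + (3 * ortho_nd), and the cleanup host task frees one pointer instead of two. That also drops one copy event from all_deps, which is why the reserve goes from depends.size() + 2 to depends.size() + 1. Below is a minimal, self-contained sketch of this packing pattern written against plain SYCL 2020 USM calls; it is not dpctl's device_allocate_and_pack helper, and the vectors, sizes, and placeholder kernel are invented purely for illustration.

// Sketch only: approximates the single-allocation packing used in the diff,
// not the dpctl helper itself. All data here is made up for illustration.
#include <sycl/sycl.hpp>

#include <cstdint>
#include <stdexcept>
#include <vector>

using index_t = std::int64_t; // stand-in for py::ssize_t

int main()
{
    sycl::queue q;

    // Hypothetical shape/stride vectors standing in for the
    // simplified_ortho_* and masked_* vectors in the diff above.
    std::vector<index_t> ortho_shape{4, 5};
    std::vector<index_t> ortho_src_strides{5, 1};
    std::vector<index_t> ortho_dst_strides{5, 1};
    std::vector<index_t> masked_shape{20};
    std::vector<index_t> masked_strides{1};
    const int ortho_nd = static_cast<int>(ortho_shape.size());

    // Pack all five vectors back-to-back on the host ...
    std::vector<index_t> host_pack;
    for (const auto *v : {&ortho_shape, &ortho_src_strides, &ortho_dst_strides,
                          &masked_shape, &masked_strides})
    {
        host_pack.insert(host_pack.end(), v->begin(), v->end());
    }

    // ... and make ONE device allocation and ONE copy, instead of two.
    index_t *packed_shapes_strides =
        sycl::malloc_device<index_t>(host_pack.size(), q);
    if (packed_shapes_strides == nullptr) {
        throw std::runtime_error("Unable to allocate device memory");
    }
    sycl::event copy_shapes_strides_ev =
        q.memcpy(packed_shapes_strides, host_pack.data(),
                 host_pack.size() * sizeof(index_t));

    // The masked block is just an offset into the same allocation,
    // mirroring packed_shapes_strides + (3 * ortho_nd) in the diff.
    index_t *packed_masked_shape_strides =
        packed_shapes_strides + (3 * ortho_nd);

    // Stand-in for the indexing kernel: it now depends on a single copy event.
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_shapes_strides_ev);
        cgh.single_task([=]() {
            // a real kernel would read packed_shapes_strides and
            // packed_masked_shape_strides here
            (void)packed_masked_shape_strides;
        });
    });

    // A single host_task frees the single allocation once the kernel is done.
    sycl::event cleanup_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev);
        auto ctx = q.get_context();
        cgh.host_task([ctx, packed_shapes_strides] {
            sycl::free(packed_shapes_strides, ctx);
        });
    });

    cleanup_ev.wait();
    return 0;
}

The point of the pattern is that one allocation means one copy, one dependency event, and one deferred free, which is exactly what the refactored host tasks in the diff rely on.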
