@@ -385,6 +385,9 @@ py_extract(dpctl::tensor::usm_ndarray src,
         auto fn =
             masked_extract_all_slices_strided_impl_dispatch_vector[src_typeid];
 
+        assert(dst_shape_vec.size() == 1);
+        assert(dst_strides_vec.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
@@ -397,9 +400,6 @@ py_extract(dpctl::tensor::usm_ndarray src,
         sycl::event copy_src_shape_strides_ev =
             std::get<2>(ptr_size_event_tuple1);
 
-        assert(dst_shape_vec.size() == 1);
-        assert(dst_strides_vec.size() == 1);
-
         std::vector<sycl::event> all_deps;
         all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
@@ -469,40 +469,31 @@ py_extract(dpctl::tensor::usm_ndarray src,
             simplified_ortho_shape, simplified_ortho_src_strides,
             simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);
 
+        assert(masked_dst_shape.size() == 1);
+        assert(masked_dst_strides.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_src_strides, simplified_ortho_dst_strides);
-        py::ssize_t *packed_ortho_src_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_src_dst_shape_strides == nullptr) {
+                simplified_ortho_src_strides, simplified_ortho_dst_strides,
+                masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
 
-        const auto &ptr_size_event_tuple2 =
-            device_allocate_and_pack<py::ssize_t>(
-                exec_q, host_task_events, masked_src_shape, masked_src_strides);
+        py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_src_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_src_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_src_dst_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
-
-        assert(masked_dst_shape.size() == 1);
-        assert(masked_dst_strides.size() == 1);
+            packed_shapes_strides + (3 * ortho_nd);
 
         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
+        all_deps.push_back(copy_shapes_strides_ev);
 
-        assert(all_deps.size() == depends.size() + 2);
+        assert(all_deps.size() == depends.size() + 1);
 
         // OrthogIndexerT orthog_src_dst_indexer_, MaskedIndexerT
         // masked_src_indexer_, MaskedIndexerT masked_dst_indexer_
@@ -520,10 +511,8 @@ py_extract(dpctl::tensor::usm_ndarray src,
             exec_q.submit([&](sycl::handler &cgh) {
                 cgh.depends_on(extract_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_src_dst_shape_strides,
-                               packed_masked_src_shape_strides] {
-                    sycl::free(packed_ortho_src_dst_shape_strides, ctx);
-                    sycl::free(packed_masked_src_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
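
The pointer arithmetic introduced above assumes that device_allocate_and_pack copies its vector arguments into the single allocation contiguously, in argument order, and that the three simplified ortho vectors each hold ortho_nd entries. A minimal sketch of the assumed layout, with names mirroring the diff:

    // Assumed layout of the single packed allocation (py_extract case),
    // given that each of the three ortho vectors has ortho_nd elements:
    //
    //   [0,            ortho_nd)   simplified_ortho_shape
    //   [ortho_nd,   2*ortho_nd)   simplified_ortho_src_strides
    //   [2*ortho_nd, 3*ortho_nd)   simplified_ortho_dst_strides
    //   [3*ortho_nd,        ...)   masked_src_shape, masked_src_strides
    //
    py::ssize_t *packed_ortho_src_dst_shape_strides = packed_shapes_strides;
    py::ssize_t *packed_masked_src_shape_strides =
        packed_shapes_strides + (3 * ortho_nd);

The same layout argument applies to the py_place changes below, with dst/rhs strides in place of src/dst strides.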
@@ -684,13 +673,16 @@ py_place(dpctl::tensor::usm_ndarray dst,
     auto rhs_shape_vec = rhs.get_shape_vector();
     auto rhs_strides_vec = rhs.get_strides_vector();
 
-    sycl::event extract_ev;
+    sycl::event place_ev;
     std::vector<sycl::event> host_task_events{};
     if (axis_start == 0 && axis_end == dst_nd) {
         // empty orthogonal directions
         auto fn =
             masked_place_all_slices_strided_impl_dispatch_vector[dst_typeid];
 
+        assert(rhs_shape_vec.size() == 1);
+        assert(rhs_strides_vec.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
@@ -703,23 +695,20 @@ py_place(dpctl::tensor::usm_ndarray dst,
         sycl::event copy_dst_shape_strides_ev =
             std::get<2>(ptr_size_event_tuple1);
 
-        assert(rhs_shape_vec.size() == 1);
-        assert(rhs_strides_vec.size() == 1);
-
         std::vector<sycl::event> all_deps;
         all_deps.reserve(depends.size() + 1);
         all_deps.insert(all_deps.end(), depends.begin(), depends.end());
         all_deps.push_back(copy_dst_shape_strides_ev);
 
         assert(all_deps.size() == depends.size() + 1);
 
-        extract_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p,
-                        rhs_data_p, dst_nd, packed_dst_shape_strides,
-                        rhs_shape_vec[0], rhs_strides_vec[0], all_deps);
+        place_ev = fn(exec_q, cumsum_sz, dst_data_p, cumsum_data_p, rhs_data_p,
+                      dst_nd, packed_dst_shape_strides, rhs_shape_vec[0],
+                      rhs_strides_vec[0], all_deps);
 
         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
                 cgh.host_task([ctx, packed_dst_shape_strides] {
                     sycl::free(packed_dst_shape_strides, ctx);
@@ -774,69 +763,59 @@ py_place(dpctl::tensor::usm_ndarray dst,
             simplified_ortho_shape, simplified_ortho_dst_strides,
             simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);
 
+        assert(masked_rhs_shape.size() == 1);
+        assert(masked_rhs_strides.size() == 1);
+
         using dpctl::tensor::offset_utils::device_allocate_and_pack;
         const auto &ptr_size_event_tuple1 =
             device_allocate_and_pack<py::ssize_t>(
                 exec_q, host_task_events, simplified_ortho_shape,
-                simplified_ortho_dst_strides, simplified_ortho_rhs_strides);
-        py::ssize_t *packed_ortho_dst_rhs_shape_strides =
-            std::get<0>(ptr_size_event_tuple1);
-        if (packed_ortho_dst_rhs_shape_strides == nullptr) {
+                simplified_ortho_dst_strides, simplified_ortho_rhs_strides,
+                masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_shapes_strides = std::get<0>(ptr_size_event_tuple1);
+        if (packed_shapes_strides == nullptr) {
             throw std::runtime_error("Unable to allocate device memory");
         }
-        sycl::event copy_shape_strides_ev1 = std::get<2>(ptr_size_event_tuple1);
+        sycl::event copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple1);
 
-        auto ptr_size_event_tuple2 = device_allocate_and_pack<py::ssize_t>(
-            exec_q, host_task_events, masked_dst_shape, masked_dst_strides);
+        py::ssize_t *packed_ortho_dst_rhs_shape_strides = packed_shapes_strides;
         py::ssize_t *packed_masked_dst_shape_strides =
-            std::get<0>(ptr_size_event_tuple2);
-        if (packed_masked_dst_shape_strides == nullptr) {
-            copy_shape_strides_ev1.wait();
-            sycl::free(packed_ortho_dst_rhs_shape_strides, exec_q);
-            throw std::runtime_error("Unable to allocate device memory");
-        }
-        sycl::event copy_shape_strides_ev2 = std::get<2>(ptr_size_event_tuple2);
-
-        assert(masked_rhs_shape.size() == 1);
-        assert(masked_rhs_strides.size() == 1);
+            packed_shapes_strides + (3 * ortho_nd);
 
         std::vector<sycl::event> all_deps;
-        all_deps.reserve(depends.size() + 2);
+        all_deps.reserve(depends.size() + 1);
        all_deps.insert(all_deps.end(), depends.begin(), depends.end());
-        all_deps.push_back(copy_shape_strides_ev1);
-        all_deps.push_back(copy_shape_strides_ev2);
-
-        assert(all_deps.size() == depends.size() + 2);
-
-        extract_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
-                        cumsum_data_p, rhs_data_p,
-                        // data to build orthog_dst_rhs_indexer
-                        ortho_nd, packed_ortho_dst_rhs_shape_strides,
-                        ortho_dst_offset, ortho_rhs_offset,
-                        // data to build masked_dst_indexer
-                        masked_dst_nd, packed_masked_dst_shape_strides,
-                        // data to build masked_dst_indexer,
-                        masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
+        all_deps.push_back(copy_shapes_strides_ev);
+
+        assert(all_deps.size() == depends.size() + 1);
+
+        place_ev = fn(exec_q, ortho_nelems, masked_dst_nelems, dst_data_p,
+                      cumsum_data_p, rhs_data_p,
+                      // data to build orthog_dst_rhs_indexer
+                      ortho_nd, packed_ortho_dst_rhs_shape_strides,
+                      ortho_dst_offset, ortho_rhs_offset,
+                      // data to build masked_dst_indexer
+                      masked_dst_nd, packed_masked_dst_shape_strides,
+                      // data to build masked_dst_indexer,
+                      masked_rhs_shape[0], masked_rhs_strides[0], all_deps);
 
         sycl::event cleanup_tmp_allocations_ev =
             exec_q.submit([&](sycl::handler &cgh) {
-                cgh.depends_on(extract_ev);
+                cgh.depends_on(place_ev);
                 auto ctx = exec_q.get_context();
-                cgh.host_task([ctx, packed_ortho_dst_rhs_shape_strides,
-                               packed_masked_dst_shape_strides] {
-                    sycl::free(packed_ortho_dst_rhs_shape_strides, ctx);
-                    sycl::free(packed_masked_dst_shape_strides, ctx);
+                cgh.host_task([ctx, packed_shapes_strides] {
+                    sycl::free(packed_shapes_strides, ctx);
                 });
             });
         host_task_events.push_back(cleanup_tmp_allocations_ev);
     }
 
-    host_task_events.push_back(extract_ev);
+    host_task_events.push_back(place_ev);
 
     sycl::event py_obj_management_host_task_ev = dpctl::utils::keep_args_alive(
         exec_q, {dst, cumsum, rhs}, host_task_events);
 
-    return std::make_pair(py_obj_management_host_task_ev, extract_ev);
+    return std::make_pair(py_obj_management_host_task_ev, place_ev);
 }
 
 // Non-zero
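
Both functions now make one device_allocate_and_pack call per branch instead of two. The sketch below illustrates the consolidated packing idiom under stated assumptions: allocate_and_pack is a hypothetical stand-in for dpctl::tensor::offset_utils::device_allocate_and_pack, reduced here to concatenating host vectors and copying them into a single USM allocation; the real helper additionally ties its staging-buffer cleanup into host_task_events.

    // Minimal sketch, assuming SYCL 2020 USM and queue::copy semantics;
    // not the dpctl implementation.
    #include <sycl/sycl.hpp>
    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    using ssizeT = std::int64_t;

    std::pair<ssizeT *, sycl::event>
    allocate_and_pack(sycl::queue &q,
                      const std::vector<std::vector<ssizeT>> &vecs)
    {
        std::size_t total = 0;
        for (const auto &v : vecs) {
            total += v.size();
        }

        // One device allocation for all vectors, instead of one each
        ssizeT *dev_ptr = sycl::malloc_device<ssizeT>(total, q);
        if (dev_ptr == nullptr) {
            return {nullptr, sycl::event{}};
        }

        // Concatenate on the host so a single copy suffices
        auto host_buf = std::make_shared<std::vector<ssizeT>>();
        host_buf->reserve(total);
        for (const auto &v : vecs) {
            host_buf->insert(host_buf->end(), v.begin(), v.end());
        }

        sycl::event copy_ev = q.copy<ssizeT>(host_buf->data(), dev_ptr, total);

        // Keep the staging buffer alive until the copy completes: the
        // shared_ptr captured by the host task extends its lifetime.
        sycl::event keep_alive_ev = q.submit([&](sycl::handler &cgh) {
            cgh.depends_on(copy_ev);
            cgh.host_task([host_buf]() {});
        });

        return {dev_ptr, keep_alive_ev};
    }

A caller carves sub-pointers at known offsets out of the returned allocation, exactly as the packed_shapes_strides + (3 * ortho_nd) lines above do, and submits one host task that frees the single allocation once the consuming kernel's event completes. Collapsing two allocations into one halves the allocation, copy-event, and cleanup bookkeeping, and removes the partial-failure path that previously had to wait on the first copy and free the first buffer by hand.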