@@ -479,6 +479,46 @@ void simplify_iteration_space(int &nd,
479
479
}
480
480
}
481
481
482
+ sycl::event _populate_packed_shape_strides_for_copycast_kernel (
483
+ sycl::queue exec_q,
484
+ int src_flags,
485
+ int dst_flags,
486
+ py::ssize_t *device_shape_strides, // to be populated
487
+ const std::vector<py::ssize_t > &common_shape,
488
+ const std::vector<py::ssize_t > &src_strides,
489
+ const std::vector<py::ssize_t > &dst_strides)
490
+ {
491
+ using shT = std::vector<py::ssize_t >;
492
+ size_t nd = common_shape.size ();
493
+
494
+ // create host temporary for packed shape and strides managed by shared
495
+ // pointer. Packed vector is concatenation of common_shape, src_stride and
496
+ // std_strides
497
+ std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
498
+ std::copy (common_shape.begin (), common_shape.end (),
499
+ shp_host_shape_strides->begin ());
500
+
501
+ std::copy (src_strides.begin (), src_strides.end (),
502
+ shp_host_shape_strides->begin () + nd);
503
+
504
+ std::copy (dst_strides.begin (), dst_strides.end (),
505
+ shp_host_shape_strides->begin () + 2 * nd);
506
+
507
+ sycl::event copy_shape_ev = exec_q.copy <py::ssize_t >(
508
+ shp_host_shape_strides->data (), device_shape_strides,
509
+ shp_host_shape_strides->size ());
510
+
511
+ exec_q.submit ([&](sycl::handler &cgh) {
512
+ cgh.depends_on (copy_shape_ev);
513
+ cgh.host_task ([shp_host_shape_strides]() {
514
+ // increment shared pointer ref-count to keep it alive
515
+ // till copy operation completes;
516
+ });
517
+ });
518
+
519
+ return copy_shape_ev;
520
+ }
521
+
482
522
std::pair<sycl::event, sycl::event>
483
523
copy_usm_ndarray_into_usm_ndarray (dpctl::tensor::usm_ndarray src,
484
524
dpctl::tensor::usm_ndarray dst,
@@ -677,47 +717,10 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
677
717
throw std::runtime_error (" Unabled to allocate device memory" );
678
718
}
679
719
680
- // create host temporary for packed shape and strides managed by shared
681
- // pointer
682
- std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
683
- std::copy (simplified_shape.begin (), simplified_shape.end (),
684
- shp_host_shape_strides->begin ());
685
-
686
- if (src_strides == nullptr ) {
687
- const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
688
- ? c_contiguous_strides (nd, shape)
689
- : f_contiguous_strides (nd, shape);
690
- std::copy (src_contig_strides.begin (), src_contig_strides.end (),
691
- shp_host_shape_strides->begin () + nd);
692
- }
693
- else {
694
- std::copy (simplified_src_strides.begin (), simplified_src_strides.end (),
695
- shp_host_shape_strides->begin () + nd);
696
- }
697
-
698
- if (dst_strides == nullptr ) {
699
- const shT &dst_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
700
- ? c_contiguous_strides (nd, shape)
701
- : f_contiguous_strides (nd, shape);
702
- std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
703
- shp_host_shape_strides->begin () + 2 * nd);
704
- }
705
- else {
706
- std::copy (simplified_dst_strides.begin (), simplified_dst_strides.end (),
707
- shp_host_shape_strides->begin () + nd);
708
- }
709
-
710
720
sycl::event copy_shape_ev =
711
- exec_q.copy <py::ssize_t >(shp_host_shape_strides->data (), shape_strides,
712
- shp_host_shape_strides->size ());
713
-
714
- exec_q.submit ([&](sycl::handler &cgh) {
715
- cgh.depends_on (copy_shape_ev);
716
- cgh.host_task ([shp_host_shape_strides]() {
717
- // increment shared pointer ref-count to keep it alive
718
- // till copy operation completes;
719
- });
720
- });
721
+ _populate_packed_shape_strides_for_copycast_kernel (
722
+ exec_q, nd, src_flags, dst_flags, shape_strides, simplified_shape,
723
+ simplified_src_strides, simplified_dst_strides);
721
724
722
725
sycl::event copy_and_cast_generic_ev = copy_and_cast_fn (
723
726
exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
0 commit comments