Skip to content

Commit ffb42f9

Browse files
Modularized utility for packing shapes/strides into device allocation for copy-and-cast operation between two usm_ndarrays
1 parent 7ab929f commit ffb42f9

File tree

1 file changed

+43
-40
lines changed

1 file changed

+43
-40
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,46 @@ void simplify_iteration_space(int &nd,
479479
}
480480
}
481481

482+
/// @brief Packs the shape and both stride vectors into one contiguous host
///        buffer and submits a copy of that buffer to
///        `device_shape_strides`.
///
/// The packed layout is the concatenation
/// `[common_shape | src_strides | dst_strides]`, i.e. `3 * nd` elements
/// where `nd == common_shape.size()`.
///
/// @param exec_q               Queue on which the copy and the lifetime-
///                             extending host task are submitted.
/// @param src_flags            Not used by this function; kept so the
///                             signature stays stable for callers.
/// @param dst_flags            Not used by this function; kept so the
///                             signature stays stable for callers.
/// @param device_shape_strides Destination of the copy. `3 * nd` elements
///                             are written to it, so it must have room for
///                             that many (presumably a USM allocation made
///                             by the caller -- confirm at the call site).
/// @param common_shape         Shape shared by both arrays; defines `nd`.
/// @param src_strides          Strides of the source array, `nd` elements.
/// @param dst_strides          Strides of the destination array, `nd`
///                             elements.
///
/// @return Event associated with the host-to-destination copy. The event of
///         the auxiliary host task is not returned.
///
/// @throws std::runtime_error if `src_strides` or `dst_strides` does not
///         have exactly `nd` elements.
sycl::event _populate_packed_shape_strides_for_copycast_kernel(
    sycl::queue exec_q,
    int src_flags,
    int dst_flags,
    py::ssize_t *device_shape_strides, // to be populated
    const std::vector<py::ssize_t> &common_shape,
    const std::vector<py::ssize_t> &src_strides,
    const std::vector<py::ssize_t> &dst_strides)
{
    using shT = std::vector<py::ssize_t>;

    // Flags are currently not needed to build the packed vector.
    static_cast<void>(src_flags);
    static_cast<void>(dst_flags);

    const size_t nd = common_shape.size();

    // The packed buffer is sized from common_shape only, so stride vectors
    // of any other length would either overflow it or leave slots unset.
    if (src_strides.size() != nd || dst_strides.size() != nd) {
        throw std::runtime_error(
            "Shape and strides must have the same number of elements");
    }

    // Host temporary for the packed vector, managed by a shared pointer so
    // that its lifetime can be extended past the end of this function.
    // Packed vector is the concatenation of common_shape, src_strides and
    // dst_strides.
    std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);

    // std::copy returns the iterator one past the last element written,
    // which is exactly where the next segment has to start.
    auto packed_it = std::copy(common_shape.begin(), common_shape.end(),
                               shp_host_shape_strides->begin());
    packed_it = std::copy(src_strides.begin(), src_strides.end(), packed_it);
    std::copy(dst_strides.begin(), dst_strides.end(), packed_it);

    sycl::event copy_shape_ev = exec_q.copy<py::ssize_t>(
        shp_host_shape_strides->data(), device_shape_strides,
        shp_host_shape_strides->size());

    exec_q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_shape_ev);
        cgh.host_task([shp_host_shape_strides]() {
            // The by-value capture holds a reference to the host buffer,
            // keeping it alive until the copy operation completes.
        });
    });

    return copy_shape_ev;
}
521+
482522
std::pair<sycl::event, sycl::event>
483523
copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
484524
dpctl::tensor::usm_ndarray dst,
@@ -677,47 +717,10 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
677717
throw std::runtime_error("Unabled to allocate device memory");
678718
}
679719

680-
// create host temporary for packed shape and strides managed by shared
681-
// pointer
682-
std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
683-
std::copy(simplified_shape.begin(), simplified_shape.end(),
684-
shp_host_shape_strides->begin());
685-
686-
if (src_strides == nullptr) {
687-
const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
688-
? c_contiguous_strides(nd, shape)
689-
: f_contiguous_strides(nd, shape);
690-
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
691-
shp_host_shape_strides->begin() + nd);
692-
}
693-
else {
694-
std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
695-
shp_host_shape_strides->begin() + nd);
696-
}
697-
698-
if (dst_strides == nullptr) {
699-
const shT &dst_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
700-
? c_contiguous_strides(nd, shape)
701-
: f_contiguous_strides(nd, shape);
702-
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
703-
shp_host_shape_strides->begin() + 2 * nd);
704-
}
705-
else {
706-
std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
707-
shp_host_shape_strides->begin() + nd);
708-
}
709-
710720
sycl::event copy_shape_ev =
711-
exec_q.copy<py::ssize_t>(shp_host_shape_strides->data(), shape_strides,
712-
shp_host_shape_strides->size());
713-
714-
exec_q.submit([&](sycl::handler &cgh) {
715-
cgh.depends_on(copy_shape_ev);
716-
cgh.host_task([shp_host_shape_strides]() {
717-
// increment shared pointer ref-count to keep it alive
718-
// till copy operation completes;
719-
});
720-
});
721+
_populate_packed_shape_strides_for_copycast_kernel(
722+
exec_q, nd, src_flags, dst_flags, shape_strides, simplified_shape,
723+
simplified_src_strides, simplified_dst_strides);
721724

722725
sycl::event copy_and_cast_generic_ev = copy_and_cast_fn(
723726
exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,

0 commit comments

Comments
 (0)