Skip to content

Simplified API for simplify_iteration_space* functions #1188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 6 additions & 23 deletions dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
}

const py::ssize_t *shape = mask.get_shape_raw();
const py::ssize_t *strides = mask.get_strides_raw();
auto const &strides_vector = mask.get_strides_vector();

using shT = std::vector<py::ssize_t>;
shT simplified_shape;
Expand All @@ -187,13 +187,9 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
int mask_nd = mask.get_ndim();
int nd = mask_nd;

constexpr py::ssize_t itemsize = 1; // in elements
bool is_c_contig = mask.is_c_contiguous();
bool is_f_contig = mask.is_f_contiguous();

dpctl::tensor::py_internal::simplify_iteration_space_1(
nd, shape, strides, itemsize, is_c_contig, is_f_contig,
simplified_shape, simplified_strides, offset);
nd, shape, strides_vector, simplified_shape, simplified_strides,
offset);

if (nd == 1 && simplified_strides[0] == 1) {
auto fn = mask_positions_contig_dispatch_vector[mask_typeid];
Expand Down Expand Up @@ -463,19 +459,13 @@ py_extract(dpctl::tensor::usm_ndarray src,
std::vector<py::ssize_t> simplified_ortho_dst_strides;

const py::ssize_t *_shape = ortho_src_shape.data();
const py::ssize_t *_src_strides = ortho_src_strides.data();
const py::ssize_t *_dst_strides = ortho_dst_strides.data();
constexpr py::ssize_t _itemsize = 1; // in elements

constexpr bool is_c_contig = false;
constexpr bool is_f_contig = false;

py::ssize_t ortho_src_offset(0);
py::ssize_t ortho_dst_offset(0);

dpctl::tensor::py_internal::simplify_iteration_space(
ortho_nd, _shape, _src_strides, _itemsize, is_c_contig, is_f_contig,
_dst_strides, _itemsize, is_c_contig, is_f_contig,
ortho_nd, _shape, ortho_src_strides, ortho_dst_strides,
// output
simplified_ortho_shape, simplified_ortho_src_strides,
simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);

Expand Down Expand Up @@ -775,19 +765,12 @@ py_place(dpctl::tensor::usm_ndarray dst,
std::vector<py::ssize_t> simplified_ortho_rhs_strides;

const py::ssize_t *_shape = ortho_dst_shape.data();
const py::ssize_t *_dst_strides = ortho_dst_strides.data();
const py::ssize_t *_rhs_strides = ortho_rhs_strides.data();
constexpr py::ssize_t _itemsize = 1; // in elements

constexpr bool is_c_contig = false;
constexpr bool is_f_contig = false;

py::ssize_t ortho_dst_offset(0);
py::ssize_t ortho_rhs_offset(0);

dpctl::tensor::py_internal::simplify_iteration_space(
ortho_nd, _shape, _dst_strides, _itemsize, is_c_contig, is_f_contig,
_rhs_strides, _itemsize, is_c_contig, is_f_contig,
ortho_nd, _shape, ortho_dst_strides, ortho_rhs_strides,
simplified_ortho_shape, simplified_ortho_dst_strides,
simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);

Expand Down
21 changes: 8 additions & 13 deletions dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
copy_ev);
}

const py::ssize_t *src_strides = src.get_strides_raw();
const py::ssize_t *dst_strides = dst.get_strides_raw();
auto const &src_strides = src.get_strides_vector();
auto const &dst_strides = dst.get_strides_vector();

using shT = std::vector<py::ssize_t>;
shT simplified_shape;
Expand All @@ -180,25 +180,20 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
int nd = src_nd;
const py::ssize_t *shape = src_shape;

constexpr py::ssize_t src_itemsize = 1; // in elements
constexpr py::ssize_t dst_itemsize = 1; // in elements

// all args except itemsizes and is_?_contig bools can be modified by
// reference
// nd, simplified_* and *_offset are modified by reference
dpctl::tensor::py_internal::simplify_iteration_space(
nd, shape, src_strides, src_itemsize, is_src_c_contig, is_src_f_contig,
dst_strides, dst_itemsize, is_dst_c_contig, is_dst_f_contig,
nd, shape, src_strides, dst_strides,
// output
simplified_shape, simplified_src_strides, simplified_dst_strides,
src_offset, dst_offset);

if (nd < 2) {
if (nd == 1) {
std::array<py::ssize_t, 1> shape_arr = {shape[0]};
// strides may be null
std::array<py::ssize_t, 1> shape_arr = {simplified_shape[0]};
std::array<py::ssize_t, 1> src_strides_arr = {
(src_strides ? src_strides[0] : 1)};
simplified_src_strides[0]};
std::array<py::ssize_t, 1> dst_strides_arr = {
(dst_strides ? dst_strides[0] : 1)};
simplified_dst_strides[0]};

sycl::event copy_and_cast_1d_event;
if ((src_strides_arr[0] == 1) && (dst_strides_arr[0] == 1) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
//===----------------------------------------------------------------------===//

#include <CL/sycl.hpp>
#include <algorithm>
#include <vector>

#include "dpctl4pybind11.hpp"
Expand Down Expand Up @@ -143,10 +144,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
}
}

const py::ssize_t *src_strides =
npy_src.strides(); // N.B.: strides in bytes
const py::ssize_t *dst_strides =
dst.get_strides_raw(); // N.B.: strides in elements
auto const &dst_strides =
dst.get_strides_vector(); // N.B.: strides in elements

using shT = std::vector<py::ssize_t>;
shT simplified_shape;
Expand All @@ -155,23 +154,42 @@ void copy_numpy_ndarray_into_usm_ndarray(
py::ssize_t src_offset(0);
py::ssize_t dst_offset(0);

py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes
constexpr py::ssize_t dst_itemsize = 1; // item size in elements

int nd = src_ndim;
const py::ssize_t *shape = src_shape;

const py::ssize_t *src_strides_p =
npy_src.strides(); // N.B.: strides in bytes
py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes

bool is_src_c_contig = ((src_flags & py::array::c_style) != 0);
bool is_src_f_contig = ((src_flags & py::array::f_style) != 0);

bool is_dst_c_contig = dst.is_c_contiguous();
bool is_dst_f_contig = dst.is_f_contiguous();
shT src_strides_in_elems;
if (src_strides_p) {
src_strides_in_elems.resize(nd);
// copy and convert strides from bytes to elements
std::transform(
src_strides_p, src_strides_p + nd, std::begin(src_strides_in_elems),
[src_itemsize](py::ssize_t el) { return el / src_itemsize; });
}
else {
if (is_src_c_contig) {
src_strides_in_elems =
dpctl::tensor::c_contiguous_strides(nd, src_shape);
}
else if (is_src_f_contig) {
src_strides_in_elems =
dpctl::tensor::f_contiguous_strides(nd, src_shape);
}
else {
throw py::value_error("NumPy source array has null strides but is "
"neither C- nor F-contiguous.");
}
}

// all args except itemsizes and is_?_contig bools can be modified by
// reference
simplify_iteration_space(nd, shape, src_strides, src_itemsize,
is_src_c_contig, is_src_f_contig, dst_strides,
dst_itemsize, is_dst_c_contig, is_dst_f_contig,
// nd, simplified_* vectors and offsets are modified by reference
simplify_iteration_space(nd, shape, src_strides_in_elems, dst_strides,
// outputs
simplified_shape, simplified_src_strides,
simplified_dst_strides, src_offset, dst_offset);

Expand All @@ -186,18 +204,16 @@ void copy_numpy_ndarray_into_usm_ndarray(
simplified_shape.push_back(1);

simplified_src_strides.reserve(nd);
simplified_src_strides.push_back(src_itemsize);
simplified_src_strides.push_back(1);

simplified_dst_strides.reserve(nd);
simplified_dst_strides.push_back(dst_itemsize);
simplified_dst_strides.push_back(1);
}

// Minimum and maximum element offsets for source np.ndarray
py::ssize_t npy_src_min_nelem_offset(0);
py::ssize_t npy_src_max_nelem_offset(0);
for (int i = 0; i < nd; ++i) {
// convert source strides from bytes to elements
simplified_src_strides[i] = simplified_src_strides[i] / src_itemsize;
if (simplified_src_strides[i] < 0) {
npy_src_min_nelem_offset +=
simplified_src_strides[i] * (simplified_shape[i] - 1);
Expand Down
Loading