Skip to content

Commit e01e270

Browse files
Merge pull request #1188 from IntelPython/streamline-iter-space-simplifier-api
Simplified API for simplify_iteration_space* functions
2 parents 4a2a79f + 6ed8a7a commit e01e270

File tree

7 files changed

+171
-591
lines changed

7 files changed

+171
-591
lines changed

dpctl/tensor/libtensor/source/boolean_advanced_indexing.cpp

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
177177
}
178178

179179
const py::ssize_t *shape = mask.get_shape_raw();
180-
const py::ssize_t *strides = mask.get_strides_raw();
180+
auto const &strides_vector = mask.get_strides_vector();
181181

182182
using shT = std::vector<py::ssize_t>;
183183
shT simplified_shape;
@@ -187,13 +187,9 @@ size_t py_mask_positions(dpctl::tensor::usm_ndarray mask,
187187
int mask_nd = mask.get_ndim();
188188
int nd = mask_nd;
189189

190-
constexpr py::ssize_t itemsize = 1; // in elements
191-
bool is_c_contig = mask.is_c_contiguous();
192-
bool is_f_contig = mask.is_f_contiguous();
193-
194190
dpctl::tensor::py_internal::simplify_iteration_space_1(
195-
nd, shape, strides, itemsize, is_c_contig, is_f_contig,
196-
simplified_shape, simplified_strides, offset);
191+
nd, shape, strides_vector, simplified_shape, simplified_strides,
192+
offset);
197193

198194
if (nd == 1 && simplified_strides[0] == 1) {
199195
auto fn = mask_positions_contig_dispatch_vector[mask_typeid];
@@ -463,19 +459,13 @@ py_extract(dpctl::tensor::usm_ndarray src,
463459
std::vector<py::ssize_t> simplified_ortho_dst_strides;
464460

465461
const py::ssize_t *_shape = ortho_src_shape.data();
466-
const py::ssize_t *_src_strides = ortho_src_strides.data();
467-
const py::ssize_t *_dst_strides = ortho_dst_strides.data();
468-
constexpr py::ssize_t _itemsize = 1; // in elements
469-
470-
constexpr bool is_c_contig = false;
471-
constexpr bool is_f_contig = false;
472462

473463
py::ssize_t ortho_src_offset(0);
474464
py::ssize_t ortho_dst_offset(0);
475465

476466
dpctl::tensor::py_internal::simplify_iteration_space(
477-
ortho_nd, _shape, _src_strides, _itemsize, is_c_contig, is_f_contig,
478-
_dst_strides, _itemsize, is_c_contig, is_f_contig,
467+
ortho_nd, _shape, ortho_src_strides, ortho_dst_strides,
468+
// output
479469
simplified_ortho_shape, simplified_ortho_src_strides,
480470
simplified_ortho_dst_strides, ortho_src_offset, ortho_dst_offset);
481471

@@ -775,19 +765,12 @@ py_place(dpctl::tensor::usm_ndarray dst,
775765
std::vector<py::ssize_t> simplified_ortho_rhs_strides;
776766

777767
const py::ssize_t *_shape = ortho_dst_shape.data();
778-
const py::ssize_t *_dst_strides = ortho_dst_strides.data();
779-
const py::ssize_t *_rhs_strides = ortho_rhs_strides.data();
780-
constexpr py::ssize_t _itemsize = 1; // in elements
781-
782-
constexpr bool is_c_contig = false;
783-
constexpr bool is_f_contig = false;
784768

785769
py::ssize_t ortho_dst_offset(0);
786770
py::ssize_t ortho_rhs_offset(0);
787771

788772
dpctl::tensor::py_internal::simplify_iteration_space(
789-
ortho_nd, _shape, _dst_strides, _itemsize, is_c_contig, is_f_contig,
790-
_rhs_strides, _itemsize, is_c_contig, is_f_contig,
773+
ortho_nd, _shape, ortho_dst_strides, ortho_rhs_strides,
791774
simplified_ortho_shape, simplified_ortho_dst_strides,
792775
simplified_ortho_rhs_strides, ortho_dst_offset, ortho_rhs_offset);
793776

dpctl/tensor/libtensor/source/copy_and_cast_usm_to_usm.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
167167
copy_ev);
168168
}
169169

170-
const py::ssize_t *src_strides = src.get_strides_raw();
171-
const py::ssize_t *dst_strides = dst.get_strides_raw();
170+
auto const &src_strides = src.get_strides_vector();
171+
auto const &dst_strides = dst.get_strides_vector();
172172

173173
using shT = std::vector<py::ssize_t>;
174174
shT simplified_shape;
@@ -180,25 +180,20 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
180180
int nd = src_nd;
181181
const py::ssize_t *shape = src_shape;
182182

183-
constexpr py::ssize_t src_itemsize = 1; // in elements
184-
constexpr py::ssize_t dst_itemsize = 1; // in elements
185-
186-
// all args except itemsizes and is_?_contig bools can be modified by
187-
// reference
183+
// nd, simplified_* and *_offset are modified by reference
188184
dpctl::tensor::py_internal::simplify_iteration_space(
189-
nd, shape, src_strides, src_itemsize, is_src_c_contig, is_src_f_contig,
190-
dst_strides, dst_itemsize, is_dst_c_contig, is_dst_f_contig,
185+
nd, shape, src_strides, dst_strides,
186+
// output
191187
simplified_shape, simplified_src_strides, simplified_dst_strides,
192188
src_offset, dst_offset);
193189

194190
if (nd < 2) {
195191
if (nd == 1) {
196-
std::array<py::ssize_t, 1> shape_arr = {shape[0]};
197-
// strides may be null
192+
std::array<py::ssize_t, 1> shape_arr = {simplified_shape[0]};
198193
std::array<py::ssize_t, 1> src_strides_arr = {
199-
(src_strides ? src_strides[0] : 1)};
194+
simplified_src_strides[0]};
200195
std::array<py::ssize_t, 1> dst_strides_arr = {
201-
(dst_strides ? dst_strides[0] : 1)};
196+
simplified_dst_strides[0]};
202197

203198
sycl::event copy_and_cast_1d_event;
204199
if ((src_strides_arr[0] == 1) && (dst_strides_arr[0] == 1) &&

dpctl/tensor/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
//===----------------------------------------------------------------------===//
2424

2525
#include <CL/sycl.hpp>
26+
#include <algorithm>
2627
#include <vector>
2728

2829
#include "dpctl4pybind11.hpp"
@@ -143,10 +144,8 @@ void copy_numpy_ndarray_into_usm_ndarray(
143144
}
144145
}
145146

146-
const py::ssize_t *src_strides =
147-
npy_src.strides(); // N.B.: strides in bytes
148-
const py::ssize_t *dst_strides =
149-
dst.get_strides_raw(); // N.B.: strides in elements
147+
auto const &dst_strides =
148+
dst.get_strides_vector(); // N.B.: strides in elements
150149

151150
using shT = std::vector<py::ssize_t>;
152151
shT simplified_shape;
@@ -155,23 +154,42 @@ void copy_numpy_ndarray_into_usm_ndarray(
155154
py::ssize_t src_offset(0);
156155
py::ssize_t dst_offset(0);
157156

158-
py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes
159-
constexpr py::ssize_t dst_itemsize = 1; // item size in elements
160-
161157
int nd = src_ndim;
162158
const py::ssize_t *shape = src_shape;
163159

160+
const py::ssize_t *src_strides_p =
161+
npy_src.strides(); // N.B.: strides in bytes
162+
py::ssize_t src_itemsize = npy_src.itemsize(); // item size in bytes
163+
164164
bool is_src_c_contig = ((src_flags & py::array::c_style) != 0);
165165
bool is_src_f_contig = ((src_flags & py::array::f_style) != 0);
166166

167-
bool is_dst_c_contig = dst.is_c_contiguous();
168-
bool is_dst_f_contig = dst.is_f_contiguous();
167+
shT src_strides_in_elems;
168+
if (src_strides_p) {
169+
src_strides_in_elems.resize(nd);
170+
// copy and convert strides from bytes to elements
171+
std::transform(
172+
src_strides_p, src_strides_p + nd, std::begin(src_strides_in_elems),
173+
[src_itemsize](py::ssize_t el) { return el / src_itemsize; });
174+
}
175+
else {
176+
if (is_src_c_contig) {
177+
src_strides_in_elems =
178+
dpctl::tensor::c_contiguous_strides(nd, src_shape);
179+
}
180+
else if (is_src_f_contig) {
181+
src_strides_in_elems =
182+
dpctl::tensor::f_contiguous_strides(nd, src_shape);
183+
}
184+
else {
185+
throw py::value_error("NumPy source array has null strides but is "
186+
"neither C- nor F-contiguous.");
187+
}
188+
}
169189

170-
// all args except itemsizes and is_?_contig bools can be modified by
171-
// reference
172-
simplify_iteration_space(nd, shape, src_strides, src_itemsize,
173-
is_src_c_contig, is_src_f_contig, dst_strides,
174-
dst_itemsize, is_dst_c_contig, is_dst_f_contig,
190+
// nd, simplified_* vectors and offsets are modified by reference
191+
simplify_iteration_space(nd, shape, src_strides_in_elems, dst_strides,
192+
// outputs
175193
simplified_shape, simplified_src_strides,
176194
simplified_dst_strides, src_offset, dst_offset);
177195

@@ -186,18 +204,16 @@ void copy_numpy_ndarray_into_usm_ndarray(
186204
simplified_shape.push_back(1);
187205

188206
simplified_src_strides.reserve(nd);
189-
simplified_src_strides.push_back(src_itemsize);
207+
simplified_src_strides.push_back(1);
190208

191209
simplified_dst_strides.reserve(nd);
192-
simplified_dst_strides.push_back(dst_itemsize);
210+
simplified_dst_strides.push_back(1);
193211
}
194212

195213
// Minumum and maximum element offsets for source np.ndarray
196214
py::ssize_t npy_src_min_nelem_offset(0);
197215
py::ssize_t npy_src_max_nelem_offset(0);
198216
for (int i = 0; i < nd; ++i) {
199-
// convert source strides from bytes to elements
200-
simplified_src_strides[i] = simplified_src_strides[i] / src_itemsize;
201217
if (simplified_src_strides[i] < 0) {
202218
npy_src_min_nelem_offset +=
203219
simplified_src_strides[i] * (simplified_shape[i] - 1);

0 commit comments

Comments
 (0)