Skip to content

Commit 2280ae1

Browse files
Merge pull request #1198 from IntelPython/fix-gh-1196-and-1197
2 parents 2da47aa + c5687db commit 2280ae1

File tree

5 files changed

+224
-16
lines changed

5 files changed

+224
-16
lines changed

dpctl/tensor/_reshape.py

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,11 @@
1919

2020
import dpctl.tensor as dpt
2121
from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
22-
from dpctl.tensor._tensor_impl import _copy_usm_ndarray_for_reshape
22+
from dpctl.tensor._tensor_impl import (
23+
_copy_usm_ndarray_for_reshape,
24+
_ravel_multi_index,
25+
_unravel_index,
26+
)
2327

2428
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
2529

@@ -36,6 +40,14 @@ def _make_unit_indexes(shape):
3640
return mi
3741

3842

43+
def ti_unravel_index(flat_index, shape, order="C"):
    """
    Convert `flat_index` into a multi-index for an array with the
    given `shape`.

    Thin wrapper that forwards all three arguments, positionally and
    unchanged, to `_unravel_index` imported from
    `dpctl.tensor._tensor_impl`. `order` defaults to "C".
    """
    multi_index = _unravel_index(flat_index, shape, order)
    return multi_index
45+
46+
47+
def ti_ravel_multi_index(multi_index, shape, order="C"):
    """
    Convert `multi_index` into a flat index for an array with the
    given `shape`.

    Thin wrapper that forwards all three arguments, positionally and
    unchanged, to `_ravel_multi_index` imported from
    `dpctl.tensor._tensor_impl`. `order` defaults to "C".
    """
    flat_index = _ravel_multi_index(multi_index, shape, order)
    return flat_index
49+
50+
3951
def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
4052
"""
4153
When reshaping array with `old_sh` shape and `old_sts` strides
@@ -47,11 +59,11 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
4759
sum(
4860
st_i * ind_i
4961
for st_i, ind_i in zip(
50-
old_sts, np.unravel_index(flat_index, old_sh, order=order)
62+
old_sts, ti_unravel_index(flat_index, old_sh, order=order)
5163
)
5264
)
5365
for flat_index in [
54-
np.ravel_multi_index(unitvec, new_sh, order=order)
66+
ti_ravel_multi_index(unitvec, new_sh, order=order)
5567
for unitvec in eye_new_mi
5668
]
5769
]
@@ -60,11 +72,11 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
6072
sum(
6173
st_i * ind_i
6274
for st_i, ind_i in zip(
63-
new_sts, np.unravel_index(flat_index, new_sh, order=order)
75+
new_sts, ti_unravel_index(flat_index, new_sh, order=order)
6476
)
6577
)
6678
for flat_index in [
67-
np.ravel_multi_index(unitvec, old_sh, order=order)
79+
ti_ravel_multi_index(unitvec, old_sh, order=order)
6880
for unitvec in eye_old_mi
6981
]
7082
]
@@ -123,7 +135,13 @@ def reshape(X, shape, order="C", copy=None):
123135
"value which can only be -1"
124136
)
125137
if negative_ones_count:
126-
v = X.size // (-np.prod(shape))
138+
sz = -np.prod(shape)
139+
if sz == 0:
140+
raise ValueError(
141+
f"Can not reshape array of size {X.size} into "
142+
f"shape {tuple(i for i in shape if i >= 0)}"
143+
)
144+
v = X.size // sz
127145
shape = [v if d == -1 else d for d in shape]
128146
if X.size != np.prod(shape):
129147
raise ValueError(f"Can not reshape into {shape}")

dpctl/tensor/libtensor/source/simplify_iteration_space.cpp

Lines changed: 140 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -71,12 +71,17 @@ void simplify_iteration_space_1(int &nd,
7171
nd = contracted_nd;
7272
}
7373
else if (nd == 1) {
74+
offset = 0;
7475
// Populate vectors
7576
simplified_shape.reserve(nd);
7677
simplified_shape.push_back(shape[0]);
7778

7879
simplified_strides.reserve(nd);
79-
simplified_strides.push_back(strides[0]);
80+
simplified_strides.push_back((strides[0] >= 0) ? strides[0]
81+
: -strides[0]);
82+
if ((strides[0] < 0) && (shape[0] > 1)) {
83+
offset += (shape[0] - 1) * strides[0];
84+
}
8085

8186
assert(simplified_shape.size() == static_cast<size_t>(nd));
8287
assert(simplified_strides.size() == static_cast<size_t>(nd));
@@ -128,17 +133,27 @@ void simplify_iteration_space(int &nd,
128133
nd = contracted_nd;
129134
}
130135
else if (nd == 1) {
136+
src_offset = 0;
137+
dst_offset = 0;
131138
// Populate vectors
132139
simplified_shape.reserve(nd);
133140
simplified_shape.push_back(shape[0]);
134141
assert(simplified_shape.size() == static_cast<size_t>(nd));
135142

136143
simplified_src_strides.reserve(nd);
137-
simplified_src_strides.push_back(src_strides[0]);
144+
simplified_src_strides.push_back(
145+
(src_strides[0] >= 0) ? src_strides[0] : -src_strides[0]);
146+
if ((src_strides[0] < 0) && (shape[0] > 1)) {
147+
src_offset += (shape[0] - 1) * src_strides[0];
148+
}
138149
assert(simplified_src_strides.size() == static_cast<size_t>(nd));
139150

140151
simplified_dst_strides.reserve(nd);
141-
simplified_dst_strides.push_back(dst_strides[0]);
152+
simplified_dst_strides.push_back(
153+
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
154+
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
155+
dst_offset += (shape[0] - 1) * dst_strides[0];
156+
}
142157
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
143158
}
144159
}
@@ -202,21 +217,36 @@ void simplify_iteration_space_3(
202217
nd = contracted_nd;
203218
}
204219
else if (nd == 1) {
220+
src1_offset = 0;
221+
src2_offset = 0;
222+
dst_offset = 0;
205223
// Populate vectors
206224
simplified_shape.reserve(nd);
207225
simplified_shape.push_back(shape[0]);
208226
assert(simplified_shape.size() == static_cast<size_t>(nd));
209227

210228
simplified_src1_strides.reserve(nd);
211-
simplified_src1_strides.push_back(src1_strides[0]);
229+
simplified_src1_strides.push_back(
230+
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
231+
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
232+
src1_offset += src1_strides[0] * (shape[0] - 1);
233+
}
212234
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
213235

214236
simplified_src2_strides.reserve(nd);
215-
simplified_src2_strides.push_back(src2_strides[0]);
237+
simplified_src2_strides.push_back(
238+
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
239+
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
240+
src2_offset += src2_strides[0] * (shape[0] - 1);
241+
}
216242
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
217243

218244
simplified_dst_strides.reserve(nd);
219-
simplified_dst_strides.push_back(dst_strides[0]);
245+
simplified_dst_strides.push_back(
246+
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
247+
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
248+
dst_offset += dst_strides[0] * (shape[0] - 1);
249+
}
220250
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
221251
}
222252
}
@@ -293,29 +323,129 @@ void simplify_iteration_space_4(
293323
nd = contracted_nd;
294324
}
295325
else if (nd == 1) {
326+
src1_offset = 0;
327+
src2_offset = 0;
328+
src3_offset = 0;
329+
dst_offset = 0;
296330
// Populate vectors
297331
simplified_shape.reserve(nd);
298332
simplified_shape.push_back(shape[0]);
299333
assert(simplified_shape.size() == static_cast<size_t>(nd));
300334

301335
simplified_src1_strides.reserve(nd);
302-
simplified_src1_strides.push_back(src1_strides[0]);
336+
simplified_src1_strides.push_back(
337+
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
338+
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
339+
src1_offset += src1_strides[0] * (shape[0] - 1);
340+
}
303341
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));
304342

305343
simplified_src2_strides.reserve(nd);
306-
simplified_src2_strides.push_back(src2_strides[0]);
344+
simplified_src2_strides.push_back(
345+
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
346+
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
347+
src2_offset += src2_strides[0] * (shape[0] - 1);
348+
}
307349
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));
308350

309351
simplified_src3_strides.reserve(nd);
310-
simplified_src3_strides.push_back(src3_strides[0]);
352+
simplified_src3_strides.push_back(
353+
(src3_strides[0] >= 0) ? src3_strides[0] : -src3_strides[0]);
354+
if ((src3_strides[0] < 0) && (shape[0] > 1)) {
355+
src3_offset += src3_strides[0] * (shape[0] - 1);
356+
}
311357
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));
312358

313359
simplified_dst_strides.reserve(nd);
314-
simplified_dst_strides.push_back(dst_strides[0]);
360+
simplified_dst_strides.push_back(
361+
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
362+
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
363+
dst_offset += dst_strides[0] * (shape[0] - 1);
364+
}
315365
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
316366
}
317367
}
318368

369+
// Computes the flat (linear) index that multi-index `mi` has in an array
// with extents `shape`, with the last axis varying fastest (C order).
//
// Throws py::value_error when `mi` and `shape` differ in length.
// Components of `mi` are not range-checked against the matching extent.
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
                                 std::vector<py::ssize_t> const &shape)
{
    const size_t ndim = shape.size();
    if (mi.size() != ndim) {
        throw py::value_error(
            "Multi-index and shape vectors must have the same length.");
    }

    py::ssize_t linear = 0;
    py::ssize_t stride = 1;
    // Walk the axes from last to first, growing the stride as we go.
    for (size_t remaining = ndim; remaining > 0; --remaining) {
        const size_t axis = remaining - 1;
        linear += mi.at(axis) * stride;
        stride *= shape.at(axis);
    }

    return linear;
}
387+
388+
// Computes the flat (linear) index that multi-index `mi` has in an array
// with extents `shape`, with the first axis varying fastest (F order).
//
// Throws py::value_error when `mi` and `shape` differ in length.
// Components of `mi` are not range-checked against the matching extent.
py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &mi,
                                 std::vector<py::ssize_t> const &shape)
{
    const size_t ndim = shape.size();
    if (mi.size() != ndim) {
        throw py::value_error(
            "Multi-index and shape vectors must have the same length.");
    }

    py::ssize_t linear = 0;
    py::ssize_t stride = 1;
    // Walk the axes from first to last, growing the stride as we go.
    size_t axis = 0;
    while (axis < ndim) {
        linear += mi.at(axis) * stride;
        stride *= shape.at(axis);
        ++axis;
    }

    return linear;
}
406+
407+
// Converts `flat_index` into a multi-index for an array with extents
// `shape`, with the last axis varying fastest (C order).
//
// Every axis except the first is peeled off by division with remainder;
// the first axis receives whatever quotient is left, so its component is
// not range-checked against shape[0]. An empty `shape` yields an empty
// multi-index.
//
// Throws py::value_error if any extent used as a divisor (every axis but
// the first) is zero; without this check the integer division below would
// be undefined behaviour.
std::vector<py::ssize_t> _unravel_index_c(py::ssize_t flat_index,
                                          std::vector<py::ssize_t> const &shape)
{
    const size_t nd = shape.size();
    std::vector<py::ssize_t> mi(nd, 0);

    py::ssize_t i_ = flat_index;
    for (size_t dim = 0; dim + 1 < nd; ++dim) {
        const size_t axis = nd - 1 - dim;
        const py::ssize_t si = shape[axis];
        if (si == 0) {
            // Dividing by a zero extent is undefined; report it instead.
            throw py::value_error(
                "Can not unravel index for shape with zero-sized dimension.");
        }
        const py::ssize_t q = i_ / si;
        const py::ssize_t r = (i_ - q * si);
        mi[axis] = r;
        i_ = q;
    }
    if (nd) {
        mi[0] = i_;
    }
    return mi;
}
427+
428+
// Converts `flat_index` into a multi-index for an array with extents
// `shape`, with the first axis varying fastest (F order).
//
// Every axis except the last is peeled off by division with remainder;
// the last axis receives whatever quotient is left, so its component is
// not range-checked against its extent. An empty `shape` yields an empty
// multi-index.
//
// Throws py::value_error if any extent used as a divisor (every axis but
// the last) is zero; without this check the integer division below would
// be undefined behaviour.
std::vector<py::ssize_t> _unravel_index_f(py::ssize_t flat_index,
                                          std::vector<py::ssize_t> const &shape)
{
    const size_t nd = shape.size();
    std::vector<py::ssize_t> mi(nd, 0);

    py::ssize_t i_ = flat_index;
    for (size_t dim = 0; dim + 1 < nd; ++dim) {
        const py::ssize_t si = shape[dim];
        if (si == 0) {
            // Dividing by a zero extent is undefined; report it instead.
            throw py::value_error(
                "Can not unravel index for shape with zero-sized dimension.");
        }
        const py::ssize_t q = i_ / si;
        const py::ssize_t r = (i_ - q * si);
        mi[dim] = r;
        i_ = q;
    }
    if (nd) {
        mi[nd - 1] = i_;
    }
    return mi;
}
448+
319449
} // namespace py_internal
320450
} // namespace tensor
321451
} // namespace dpctl

dpctl/tensor/libtensor/source/simplify_iteration_space.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,14 @@ void simplify_iteration_space_4(int &,
9090
py::ssize_t &,
9191
py::ssize_t &);
9292

93+
py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
94+
std::vector<py::ssize_t> const &);
95+
py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,
96+
std::vector<py::ssize_t> const &);
97+
std::vector<py::ssize_t> _unravel_index_c(py::ssize_t,
98+
std::vector<py::ssize_t> const &);
99+
std::vector<py::ssize_t> _unravel_index_f(py::ssize_t,
100+
std::vector<py::ssize_t> const &);
93101
} // namespace py_internal
94102
} // namespace tensor
95103
} // namespace dpctl

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
#include "full_ctor.hpp"
4343
#include "integer_advanced_indexing.hpp"
4444
#include "linear_sequences.hpp"
45+
#include "simplify_iteration_space.hpp"
4546
#include "triul_ctor.hpp"
4647
#include "utils/memory_overlap.hpp"
4748
#include "utils/strided_iters.hpp"
@@ -182,6 +183,37 @@ PYBIND11_MODULE(_tensor_impl, m)
182183
"as the original "
183184
"iterator, possibly in a different order.");
184185

186+
static constexpr char orderC = 'C';
187+
m.def(
188+
"_ravel_multi_index",
189+
[](const std::vector<py::ssize_t> &mi,
190+
const std::vector<py::ssize_t> &shape, char order = 'C') {
191+
if (order == orderC) {
192+
return dpctl::tensor::py_internal::_ravel_multi_index_c(mi,
193+
shape);
194+
}
195+
else {
196+
return dpctl::tensor::py_internal::_ravel_multi_index_f(mi,
197+
shape);
198+
}
199+
},
200+
"");
201+
202+
m.def(
203+
"_unravel_index",
204+
[](py::ssize_t flat_index, const std::vector<py::ssize_t> &shape,
205+
char order = 'C') {
206+
if (order == orderC) {
207+
return dpctl::tensor::py_internal::_unravel_index_c(flat_index,
208+
shape);
209+
}
210+
else {
211+
return dpctl::tensor::py_internal::_unravel_index_f(flat_index,
212+
shape);
213+
}
214+
},
215+
"");
216+
185217
m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
186218
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
187219
"number of elements using underlying 'C'-contiguous order for flat "

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1293,6 +1293,26 @@ def test_reshape():
12931293
assert A4.shape == requested_shape
12941294

12951295

1296+
def test_reshape_zero_size():
    """
    Reshaping an empty 1-D array to (-1, 0) must raise ValueError.
    """
    try:
        empty_arr = dpt.empty((0,))
    except dpctl.SyclDeviceCreationError:
        pytest.skip("No SYCL devices available")
    # The -1 entry cannot be inferred when the remaining extents multiply to 0
    requested_shape = (-1, 0)
    with pytest.raises(ValueError):
        dpt.reshape(empty_arr, requested_shape)
1303+
1304+
1305+
def test_reshape_large_ndim():
    """
    Reshape a 1-D array of 32 elements into a 32-dimensional shape whose
    leading extents are all 1 and whose last extent is 32.
    """
    ndim = 32
    # (1, 1, ..., 1, ndim): ndim - 1 unit extents followed by ndim
    target_shape = (1,) * (ndim - 1) + (ndim,)
    try:
        flat = dpt.ones(ndim, dtype="i4")
    except dpctl.SyclDeviceCreationError:
        pytest.skip("No SYCL devices available")
    reshaped = dpt.reshape(flat, target_shape)
    assert reshaped.shape == target_shape
1314+
1315+
12961316
def test_reshape_copy_kwrd():
12971317
try:
12981318
X = dpt.usm_ndarray((2, 3), "i4")

0 commit comments

Comments
 (0)