Skip to content

Fix for reshape-related bugs #1198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

import dpctl.tensor as dpt
from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
from dpctl.tensor._tensor_impl import _copy_usm_ndarray_for_reshape
from dpctl.tensor._tensor_impl import (
_copy_usm_ndarray_for_reshape,
_ravel_multi_index,
_unravel_index,
)

__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."

Expand All @@ -36,6 +40,14 @@ def _make_unit_indexes(shape):
return mi


def ti_unravel_index(flat_index, shape, order="C"):
    """Convert a flat index into a multi-index for the given ``shape``.

    Thin wrapper delegating to the ``_tensor_impl`` C++ extension helper;
    ``order`` is forwarded positionally ("C" selects row-major unraveling,
    any other value column-major in the extension's binding).
    """
    return _unravel_index(flat_index, shape, order)


def ti_ravel_multi_index(multi_index, shape, order="C"):
    """Convert a multi-index into a flat index for the given ``shape``.

    Thin wrapper delegating to the ``_tensor_impl`` C++ extension helper;
    ``order`` is forwarded positionally ("C" selects row-major raveling,
    any other value column-major in the extension's binding).
    """
    return _ravel_multi_index(multi_index, shape, order)


def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
"""
When reshaping array with `old_sh` shape and `old_sts` strides
Expand All @@ -47,11 +59,11 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
sum(
st_i * ind_i
for st_i, ind_i in zip(
old_sts, np.unravel_index(flat_index, old_sh, order=order)
old_sts, ti_unravel_index(flat_index, old_sh, order=order)
)
)
for flat_index in [
np.ravel_multi_index(unitvec, new_sh, order=order)
ti_ravel_multi_index(unitvec, new_sh, order=order)
for unitvec in eye_new_mi
]
]
Expand All @@ -60,11 +72,11 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
sum(
st_i * ind_i
for st_i, ind_i in zip(
new_sts, np.unravel_index(flat_index, new_sh, order=order)
new_sts, ti_unravel_index(flat_index, new_sh, order=order)
)
)
for flat_index in [
np.ravel_multi_index(unitvec, old_sh, order=order)
ti_ravel_multi_index(unitvec, old_sh, order=order)
for unitvec in eye_old_mi
]
]
Expand Down Expand Up @@ -123,7 +135,13 @@ def reshape(X, shape, order="C", copy=None):
"value which can only be -1"
)
if negative_ones_count:
v = X.size // (-np.prod(shape))
sz = -np.prod(shape)
if sz == 0:
raise ValueError(
f"Can not reshape array of size {X.size} into "
f"shape {tuple(i for i in shape if i >= 0)}"
)
v = X.size // sz
shape = [v if d == -1 else d for d in shape]
if X.size != np.prod(shape):
raise ValueError(f"Can not reshape into {shape}")
Expand Down
150 changes: 140 additions & 10 deletions dpctl/tensor/libtensor/source/simplify_iteration_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,17 @@ void simplify_iteration_space_1(int &nd,
nd = contracted_nd;
}
else if (nd == 1) {
offset = 0;
// Populate vectors
simplified_shape.reserve(nd);
simplified_shape.push_back(shape[0]);

simplified_strides.reserve(nd);
simplified_strides.push_back(strides[0]);
simplified_strides.push_back((strides[0] >= 0) ? strides[0]
: -strides[0]);
if ((strides[0] < 0) && (shape[0] > 1)) {
offset += (shape[0] - 1) * strides[0];
}

assert(simplified_shape.size() == static_cast<size_t>(nd));
assert(simplified_strides.size() == static_cast<size_t>(nd));
Expand Down Expand Up @@ -128,17 +133,27 @@ void simplify_iteration_space(int &nd,
nd = contracted_nd;
}
else if (nd == 1) {
src_offset = 0;
dst_offset = 0;
// Populate vectors
simplified_shape.reserve(nd);
simplified_shape.push_back(shape[0]);
assert(simplified_shape.size() == static_cast<size_t>(nd));

simplified_src_strides.reserve(nd);
simplified_src_strides.push_back(src_strides[0]);
simplified_src_strides.push_back(
(src_strides[0] >= 0) ? src_strides[0] : -src_strides[0]);
if ((src_strides[0] < 0) && (shape[0] > 1)) {
src_offset += (shape[0] - 1) * src_strides[0];
}
assert(simplified_src_strides.size() == static_cast<size_t>(nd));

simplified_dst_strides.reserve(nd);
simplified_dst_strides.push_back(dst_strides[0]);
simplified_dst_strides.push_back(
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
dst_offset += (shape[0] - 1) * dst_strides[0];
}
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
}
}
Expand Down Expand Up @@ -202,21 +217,36 @@ void simplify_iteration_space_3(
nd = contracted_nd;
}
else if (nd == 1) {
src1_offset = 0;
src2_offset = 0;
dst_offset = 0;
// Populate vectors
simplified_shape.reserve(nd);
simplified_shape.push_back(shape[0]);
assert(simplified_shape.size() == static_cast<size_t>(nd));

simplified_src1_strides.reserve(nd);
simplified_src1_strides.push_back(src1_strides[0]);
simplified_src1_strides.push_back(
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
src1_offset += src1_strides[0] * (shape[0] - 1);
}
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));

simplified_src2_strides.reserve(nd);
simplified_src2_strides.push_back(src2_strides[0]);
simplified_src2_strides.push_back(
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
src2_offset += src2_strides[0] * (shape[0] - 1);
}
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));

simplified_dst_strides.reserve(nd);
simplified_dst_strides.push_back(dst_strides[0]);
simplified_dst_strides.push_back(
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
dst_offset += dst_strides[0] * (shape[0] - 1);
}
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
}
}
Expand Down Expand Up @@ -293,29 +323,129 @@ void simplify_iteration_space_4(
nd = contracted_nd;
}
else if (nd == 1) {
src1_offset = 0;
src2_offset = 0;
src3_offset = 0;
dst_offset = 0;
// Populate vectors
simplified_shape.reserve(nd);
simplified_shape.push_back(shape[0]);
assert(simplified_shape.size() == static_cast<size_t>(nd));

simplified_src1_strides.reserve(nd);
simplified_src1_strides.push_back(src1_strides[0]);
simplified_src1_strides.push_back(
(src1_strides[0] >= 0) ? src1_strides[0] : -src1_strides[0]);
if ((src1_strides[0] < 0) && (shape[0] > 1)) {
src1_offset += src1_strides[0] * (shape[0] - 1);
}
assert(simplified_src1_strides.size() == static_cast<size_t>(nd));

simplified_src2_strides.reserve(nd);
simplified_src2_strides.push_back(src2_strides[0]);
simplified_src2_strides.push_back(
(src2_strides[0] >= 0) ? src2_strides[0] : -src2_strides[0]);
if ((src2_strides[0] < 0) && (shape[0] > 1)) {
src2_offset += src2_strides[0] * (shape[0] - 1);
}
assert(simplified_src2_strides.size() == static_cast<size_t>(nd));

simplified_src3_strides.reserve(nd);
simplified_src3_strides.push_back(src3_strides[0]);
simplified_src3_strides.push_back(
(src3_strides[0] >= 0) ? src3_strides[0] : -src3_strides[0]);
if ((src3_strides[0] < 0) && (shape[0] > 1)) {
src3_offset += src3_strides[0] * (shape[0] - 1);
}
assert(simplified_src3_strides.size() == static_cast<size_t>(nd));

simplified_dst_strides.reserve(nd);
simplified_dst_strides.push_back(dst_strides[0]);
simplified_dst_strides.push_back(
(dst_strides[0] >= 0) ? dst_strides[0] : -dst_strides[0]);
if ((dst_strides[0] < 0) && (shape[0] > 1)) {
dst_offset += dst_strides[0] * (shape[0] - 1);
}
assert(simplified_dst_strides.size() == static_cast<size_t>(nd));
}
}

py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &mi,
                                 std::vector<py::ssize_t> const &shape)
{
    // Fold the multi-index into a flat C-order (row-major) index using
    // Horner's scheme: the last axis varies fastest.
    const size_t nd = shape.size();
    if (mi.size() != nd) {
        throw py::value_error(
            "Multi-index and shape vectors must have the same length.");
    }

    py::ssize_t flat_index = 0;
    for (size_t axis = 0; axis < nd; ++axis) {
        flat_index = flat_index * shape[axis] + mi[axis];
    }

    return flat_index;
}

py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &mi,
                                 std::vector<py::ssize_t> const &shape)
{
    // Fold the multi-index into a flat F-order (column-major) index using
    // Horner's scheme: the first axis varies fastest.
    const size_t nd = shape.size();
    if (mi.size() != nd) {
        throw py::value_error(
            "Multi-index and shape vectors must have the same length.");
    }

    py::ssize_t flat_index = 0;
    for (size_t axis = nd; axis > 0; --axis) {
        flat_index = flat_index * shape[axis - 1] + mi[axis - 1];
    }

    return flat_index;
}

std::vector<py::ssize_t> _unravel_index_c(py::ssize_t flat_index,
                                          std::vector<py::ssize_t> const &shape)
{
    // Invert a C-order (row-major) ravel: peel the fastest-varying (last)
    // axes off via remainder/quotient; the slowest axis gets the residue.
    const size_t nd = shape.size();
    std::vector<py::ssize_t> multi_index(nd);

    py::ssize_t remainder = flat_index;
    for (size_t axis = nd; axis > 1; --axis) {
        const py::ssize_t extent = shape[axis - 1];
        multi_index[axis - 1] = remainder % extent;
        remainder /= extent;
    }
    if (nd) {
        multi_index[0] = remainder;
    }
    return multi_index;
}

std::vector<py::ssize_t> _unravel_index_f(py::ssize_t flat_index,
                                          std::vector<py::ssize_t> const &shape)
{
    // Invert an F-order (column-major) ravel: peel the fastest-varying
    // (first) axes off via remainder/quotient; the slowest axis gets the
    // residue.
    const size_t nd = shape.size();
    std::vector<py::ssize_t> multi_index(nd);

    py::ssize_t remainder = flat_index;
    for (size_t axis = 0; axis + 1 < nd; ++axis) {
        const py::ssize_t extent = shape[axis];
        multi_index[axis] = remainder % extent;
        remainder /= extent;
    }
    if (nd) {
        multi_index[nd - 1] = remainder;
    }
    return multi_index;
}

} // namespace py_internal
} // namespace tensor
} // namespace dpctl
8 changes: 8 additions & 0 deletions dpctl/tensor/libtensor/source/simplify_iteration_space.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ void simplify_iteration_space_4(int &,
py::ssize_t &,
py::ssize_t &);

py::ssize_t _ravel_multi_index_c(std::vector<py::ssize_t> const &,
std::vector<py::ssize_t> const &);
py::ssize_t _ravel_multi_index_f(std::vector<py::ssize_t> const &,
std::vector<py::ssize_t> const &);
std::vector<py::ssize_t> _unravel_index_c(py::ssize_t,
std::vector<py::ssize_t> const &);
std::vector<py::ssize_t> _unravel_index_f(py::ssize_t,
std::vector<py::ssize_t> const &);
} // namespace py_internal
} // namespace tensor
} // namespace dpctl
32 changes: 32 additions & 0 deletions dpctl/tensor/libtensor/source/tensor_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "full_ctor.hpp"
#include "integer_advanced_indexing.hpp"
#include "linear_sequences.hpp"
#include "simplify_iteration_space.hpp"
#include "triul_ctor.hpp"
#include "utils/memory_overlap.hpp"
#include "utils/strided_iters.hpp"
Expand Down Expand Up @@ -182,6 +183,37 @@ PYBIND11_MODULE(_tensor_impl, m)
"as the original "
"iterator, possibly in a different order.");

static constexpr char orderC = 'C';
m.def(
"_ravel_multi_index",
[](const std::vector<py::ssize_t> &mi,
const std::vector<py::ssize_t> &shape, char order = 'C') {
if (order == orderC) {
return dpctl::tensor::py_internal::_ravel_multi_index_c(mi,
shape);
}
else {
return dpctl::tensor::py_internal::_ravel_multi_index_f(mi,
shape);
}
},
"");

m.def(
"_unravel_index",
[](py::ssize_t flat_index, const std::vector<py::ssize_t> &shape,
char order = 'C') {
if (order == orderC) {
return dpctl::tensor::py_internal::_unravel_index_c(flat_index,
shape);
}
else {
return dpctl::tensor::py_internal::_unravel_index_f(flat_index,
shape);
}
},
"");

m.def("_copy_usm_ndarray_for_reshape", &copy_usm_ndarray_for_reshape,
"Copies from usm_ndarray `src` into usm_ndarray `dst` with the same "
"number of elements using underlying 'C'-contiguous order for flat "
Expand Down
20 changes: 20 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,26 @@ def test_reshape():
assert A4.shape == requested_shape


def test_reshape_zero_size():
    # A -1 entry combined with a 0 extent leaves the inferred dimension
    # ambiguous; reshape must raise rather than divide by zero.
    try:
        zero_sized = dpt.empty((0,))
    except dpctl.SyclDeviceCreationError:
        pytest.skip("No SYCL devices available")
    with pytest.raises(ValueError):
        dpt.reshape(zero_sized, (-1, 0))


def test_reshape_large_ndim():
    # Reshape into 32 dimensions (31 unit axes plus one carrying all the
    # elements) to exercise the high-dimensionality code path.
    ndim = 32
    target_shape = (1,) * (ndim - 1) + (ndim,)
    try:
        d = dpt.ones(ndim, dtype="i4")
    except dpctl.SyclDeviceCreationError:
        pytest.skip("No SYCL devices available")
    d = dpt.reshape(d, target_shape)
    assert d.shape == target_shape


def test_reshape_copy_kwrd():
try:
X = dpt.usm_ndarray((2, 3), "i4")
Expand Down