Skip to content

Commit 59ad1e7

Browse files
ncullen93ncullen93NC CullenSkylion007pre-commit-ci[bot]
authored
reshape for numpy arrays (#984)
* reshape * more tests * Update numpy.h * Update test_numpy_array.py * Update numpy.h * Update numpy.h * Update test_numpy_array.cpp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge bug * Make clang-tidy happy * Add xfail for PyPy * Fix casting issue * Address reviews on additional tests * Fix ordering * Do a little more reordering * Fix typo * Try improving tests * Fix error in reshape * Add one more reshape test * streamlining new tests; removing a few stray msg Co-authored-by: ncullen93 <ncullen.th@dartmouth.edu> Co-authored-by: NC Cullen <nicholas.c.cullen.th@dartmouth.edu> Co-authored-by: Aaron Gokaslan <skylion.aaron@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf Grosse-Kunstleve <rwgk@google.com>
1 parent 031a700 commit 59ad1e7

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

include/pybind11/numpy.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ struct npy_api {
198198
// Unused. Not removed because that affects ABI of the class.
199199
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
200200
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
201+
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
202+
201203
private:
202204
enum functions {
203205
API_PyArray_GetNDArrayCFeatureVersion = 211,
@@ -212,10 +214,11 @@ struct npy_api {
212214
API_PyArray_NewCopy = 85,
213215
API_PyArray_NewFromDescr = 94,
214216
API_PyArray_DescrNewFromType = 96,
217+
API_PyArray_Newshape = 135,
218+
API_PyArray_Squeeze = 136,
215219
API_PyArray_DescrConverter = 174,
216220
API_PyArray_EquivTypes = 182,
217221
API_PyArray_GetArrayParamsFromObject = 278,
218-
API_PyArray_Squeeze = 136,
219222
API_PyArray_SetBaseObject = 282
220223
};
221224

@@ -243,11 +246,13 @@ struct npy_api {
243246
DECL_NPY_API(PyArray_NewCopy);
244247
DECL_NPY_API(PyArray_NewFromDescr);
245248
DECL_NPY_API(PyArray_DescrNewFromType);
249+
DECL_NPY_API(PyArray_Newshape);
250+
DECL_NPY_API(PyArray_Squeeze);
246251
DECL_NPY_API(PyArray_DescrConverter);
247252
DECL_NPY_API(PyArray_EquivTypes);
248253
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
249-
DECL_NPY_API(PyArray_Squeeze);
250254
DECL_NPY_API(PyArray_SetBaseObject);
255+
251256
#undef DECL_NPY_API
252257
return api;
253258
}
@@ -785,6 +790,18 @@ class array : public buffer {
785790
if (isinstance<array>(new_array)) { *this = std::move(new_array); }
786791
}
787792

793+
/// Optional `order` parameter omitted, to be added as needed.
794+
array reshape(ShapeContainer new_shape) {
795+
detail::npy_api::PyArray_Dims d
796+
= {reinterpret_cast<Py_intptr_t *>(new_shape->data()), int(new_shape->size())};
797+
auto new_array
798+
= reinterpret_steal<array>(detail::npy_api::get().PyArray_Newshape_(m_ptr, &d, 0));
799+
if (!new_array) {
800+
throw error_already_set();
801+
}
802+
return new_array;
803+
}
804+
788805
/// Ensure that the argument is a NumPy array
789806
/// In case of an error, nullptr is returned and the Python error is cleared.
790807
static array ensure(handle h, int ExtraFlags = 0) {

tests/test_numpy_array.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,13 @@ TEST_SUBMODULE(numpy_array, sm) {
405405
return a;
406406
});
407407

408+
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
409+
return a.reshape({N, M, O});
410+
});
411+
sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
412+
return a.reshape(new_shape);
413+
});
414+
408415
sm.def("index_using_ellipsis",
409416
[](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });
410417

tests/test_numpy_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def test_array_unchecked_fixed_dims(msg):
411411
assert m.proxy_auxiliaries2_const_ref(z1)
412412

413413

414-
def test_array_unchecked_dyn_dims(msg):
414+
def test_array_unchecked_dyn_dims():
415415
z1 = np.array([[1, 2], [3, 4]], dtype="float64")
416416
m.proxy_add2_dyn(z1, 10)
417417
assert np.all(z1 == [[11, 12], [13, 14]])
@@ -444,7 +444,7 @@ def test_initializer_list():
444444
assert m.array_initializer_list4().shape == (1, 2, 3, 4)
445445

446446

447-
def test_array_resize(msg):
447+
def test_array_resize():
448448
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="float64")
449449
m.array_reshape2(a)
450450
assert a.size == 9
@@ -470,12 +470,37 @@ def test_array_resize(msg):
470470

471471

472472
@pytest.mark.xfail("env.PYPY")
473-
def test_array_create_and_resize(msg):
473+
def test_array_create_and_resize():
474474
a = m.create_and_resize(2)
475475
assert a.size == 4
476476
assert np.all(a == 42.0)
477477

478478

479+
def test_reshape_initializer_list():
480+
a = np.arange(2 * 7 * 3) + 1
481+
x = m.reshape_initializer_list(a, 2, 7, 3)
482+
assert x.shape == (2, 7, 3)
483+
assert list(x[1][4]) == [34, 35, 36]
484+
with pytest.raises(ValueError) as excinfo:
485+
m.reshape_initializer_list(a, 1, 7, 3)
486+
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (1,7,3)"
487+
488+
489+
def test_reshape_tuple():
490+
a = np.arange(3 * 7 * 2) + 1
491+
x = m.reshape_tuple(a, (3, 7, 2))
492+
assert x.shape == (3, 7, 2)
493+
assert list(x[1][4]) == [23, 24]
494+
y = m.reshape_tuple(x, (x.size,))
495+
assert y.shape == (42,)
496+
with pytest.raises(ValueError) as excinfo:
497+
m.reshape_tuple(a, (3, 7, 1))
498+
assert str(excinfo.value) == "cannot reshape array of size 42 into shape (3,7,1)"
499+
with pytest.raises(ValueError) as excinfo:
500+
m.reshape_tuple(a, ())
501+
assert str(excinfo.value) == "cannot reshape array of size 42 into shape ()"
502+
503+
479504
def test_index_using_ellipsis():
480505
a = m.index_using_ellipsis(np.zeros((5, 6, 7)))
481506
assert a.shape == (6,)

0 commit comments

Comments
 (0)