Skip to content

Commit acbcdeb

Browse files
committed
Allow any type of sequence in template ctor
1 parent 2723a38 commit acbcdeb

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

include/pybind11/numpy.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,9 @@ class array : public buffer {
328328

329329
array() : array(0, static_cast<const double *>(nullptr)) {}
330330

331-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
332-
const std::vector<size_t> &strides, const void *ptr = nullptr,
331+
template <typename Shape, typename Strides>
332+
array(const pybind11::dtype &dt, const Shape &shape,
333+
const Strides &strides, const void *ptr = nullptr,
333334
handle base = handle()) {
334335
auto& api = detail::npy_api::get();
335336
auto ndim = shape.size();
@@ -362,7 +363,8 @@ class array : public buffer {
362363
m_ptr = tmp.release().ptr();
363364
}
364365

365-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
366+
template <typename ST>
367+
array(const pybind11::dtype &dt, const std::disable_if<std::is_integral<ST>::value, ST>::type &shape,
366368
const void *ptr = nullptr, handle base = handle())
367369
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368370

@@ -525,7 +527,7 @@ class array : public buffer {
525527
throw std::runtime_error("array is not writeable");
526528
}
527529

528-
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
530+
template <typename ST> static std::vector<size_t> default_strides(const ST& shape, size_t itemsize) {
529531
auto ndim = shape.size();
530532
std::vector<size_t> strides(ndim);
531533
if (ndim) {

0 commit comments

Comments
 (0)