@@ -328,8 +328,9 @@ class array : public buffer {
328
328
329
329
array () : array(0 , static_cast <const double *>(nullptr )) {}
330
330
331
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
332
- const std::vector<size_t > &strides, const void *ptr = nullptr ,
331
+ template <typename Shape, typename Strides>
332
+ array (const pybind11::dtype &dt, const Shape &shape,
333
+ const Strides &strides, const void *ptr = nullptr ,
333
334
handle base = handle()) {
334
335
auto & api = detail::npy_api::get ();
335
336
auto ndim = shape.size ();
@@ -362,7 +363,8 @@ class array : public buffer {
362
363
m_ptr = tmp.release ().ptr ();
363
364
}
364
365
365
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
366
+ template <typename ST>
367
+ array (const pybind11::dtype &dt, const std::disable_if<std::is_integral<ST>::value, ST>::type &shape,
366
368
const void *ptr = nullptr , handle base = handle())
367
369
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368
370
@@ -525,7 +527,7 @@ class array : public buffer {
525
527
throw std::runtime_error (" array is not writeable" );
526
528
}
527
529
528
- static std::vector<size_t > default_strides (const std::vector< size_t > & shape, size_t itemsize) {
530
+ template < typename ST> static std::vector<size_t > default_strides (const ST & shape, size_t itemsize) {
529
531
auto ndim = shape.size ();
530
532
std::vector<size_t > strides (ndim);
531
533
if (ndim) {
0 commit comments