@@ -134,7 +134,12 @@ where
134134 _ => {
135135 // if the array is not contiguous, copy all elements by `ArrayBase::iter`.
136136 let dim = self . raw_dim ( ) ;
137- let strides = NpyStrides :: from_dim ( & dim, mem:: size_of :: < A > ( ) ) ;
137+ let strides = NpyStrides :: new :: < _ , A > (
138+ dim. default_strides ( )
139+ . slice ( )
140+ . iter ( )
141+ . map ( |& x| x as npyffi:: npy_intp ) ,
142+ ) ;
138143 unsafe {
139144 let array = PyArray :: < A , _ > :: new_ ( py, dim, strides. as_ptr ( ) , 0 ) ;
140145 let data_ptr = array. data ( ) ;
@@ -173,10 +178,7 @@ where
173178 D : Dimension ,
174179{
175180 fn npy_strides ( & self ) -> NpyStrides {
176- NpyStrides :: new (
177- self . strides ( ) . iter ( ) . map ( |& x| x as npyffi:: npy_intp ) ,
178- mem:: size_of :: < A > ( ) ,
179- )
181+ NpyStrides :: new :: < _ , A > ( self . strides ( ) . iter ( ) . map ( |& x| x as npyffi:: npy_intp ) )
180182 }
181183
182184 fn order ( & self ) -> Option < Order > {
@@ -190,40 +192,27 @@ where
190192 }
191193}
192194
193- /// Numpy strides with short array optimization
194- pub ( crate ) enum NpyStrides {
195- Short ( [ npyffi:: npy_intp ; 8 ] ) ,
196- Long ( Vec < npyffi:: npy_intp > ) ,
197- }
195+ /// An array of strides sufficiently large for [any NumPy array][NPY_MAXDIMS]
196+ ///
197+ /// [NPY_MAXDIMS]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L40
198+ pub ( crate ) struct NpyStrides ( [ npyffi:: npy_intp ; 32 ] ) ;
198199
199200impl NpyStrides {
200201 pub ( crate ) fn as_ptr ( & self ) -> * const npy_intp {
201- match self {
202- NpyStrides :: Short ( inner) => inner. as_ptr ( ) ,
203- NpyStrides :: Long ( inner) => inner. as_ptr ( ) ,
204- }
202+ self . 0 . as_ptr ( )
205203 }
206- fn from_dim < D : Dimension > ( dim : & D , type_size : usize ) -> Self {
207- Self :: new (
208- dim. default_strides ( )
209- . slice ( )
210- . iter ( )
211- . map ( |& x| x as npyffi:: npy_intp ) ,
212- type_size,
213- )
214- }
215- fn new ( strides : impl ExactSizeIterator < Item = npyffi:: npy_intp > , type_size : usize ) -> Self {
216- let len = strides. len ( ) ;
217- let type_size = type_size as npyffi:: npy_intp ;
218- if len <= 8 {
219- let mut res = [ 0 ; 8 ] ;
220- for ( i, s) in strides. enumerate ( ) {
221- res[ i] = s * type_size;
222- }
223- NpyStrides :: Short ( res)
224- } else {
225- NpyStrides :: Long ( strides. map ( |n| n as npyffi:: npy_intp * type_size) . collect ( ) )
204+
205+ fn new < S , A > ( strides : S ) -> Self
206+ where
207+ S : Iterator < Item = npyffi:: npy_intp > ,
208+ {
209+ let type_size = mem:: size_of :: < A > ( ) as npyffi:: npy_intp ;
210+ let mut res = [ 0 ; 32 ] ;
211+ for ( i, s) in strides. enumerate ( ) {
212+ * res. get_mut ( i)
213+ . expect ( "Only dimensionalities of up to 32 are supported" ) = s * type_size;
226214 }
215+ Self ( res)
227216 }
228217}
229218
0 commit comments