@@ -190,21 +190,22 @@ impl PyArrayDescr {
190
190
}
191
191
}
192
192
193
- /// Returns shape tuple of the sub-array if this dtype is a sub-array, and `None` otherwise.
193
+ /// Returns the shape of the sub-array.
194
+ ///
195
+ /// If the dtype is not a sub-array, an empty vector is returned.
194
196
///
195
197
/// Equivalent to [`np.dtype.shape`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html).
196
- pub fn shape ( & self ) -> Option < Vec < usize > > {
198
+ pub fn shape ( & self ) -> Vec < usize > {
197
199
if !self . has_subarray ( ) {
198
- return None ;
199
- }
200
- Some (
200
+ vec ! [ ]
201
+ } else {
201
202
// Panic-wise: numpy guarantees that shape is a tuple of non-negative integers
202
203
unsafe {
203
204
PyTuple :: from_borrowed_ptr ( self . py ( ) , ( * ( * self . as_dtype_ptr ( ) ) . subarray ) . shape )
204
205
}
205
206
. extract ( )
206
- . unwrap ( ) ,
207
- )
207
+ . unwrap ( )
208
+ }
208
209
}
209
210
210
211
/// Returns true if the dtype is a sub-array at the top level.
@@ -501,7 +502,7 @@ mod tests {
501
502
assert ! ( !dt. has_subarray( ) ) ;
502
503
assert ! ( dt. base( ) . is_equiv_to( dt) ) ;
503
504
assert_eq ! ( dt. ndim( ) , 0 ) ;
504
- assert_eq ! ( dt. shape( ) , None ) ;
505
+ assert_eq ! ( dt. shape( ) , vec! [ ] ) ;
505
506
} ) ;
506
507
}
507
508
@@ -535,7 +536,7 @@ mod tests {
535
536
assert ! ( !dt. is_aligned_struct( ) ) ;
536
537
assert ! ( dt. has_subarray( ) ) ;
537
538
assert_eq ! ( dt. ndim( ) , 2 ) ;
538
- assert_eq ! ( dt. shape( ) . unwrap ( ) , vec![ 2 , 3 ] ) ;
539
+ assert_eq ! ( dt. shape( ) , vec![ 2 , 3 ] ) ;
539
540
assert ! ( dt. base( ) . is_equiv_to( dtype:: <f64 >( py) ) ) ;
540
541
} ) ;
541
542
}
@@ -572,7 +573,7 @@ mod tests {
572
573
assert ! ( dt. is_aligned_struct( ) ) ;
573
574
assert ! ( !dt. has_subarray( ) ) ;
574
575
assert_eq ! ( dt. ndim( ) , 0 ) ;
575
- assert_eq ! ( dt. shape( ) , None ) ;
576
+ assert_eq ! ( dt. shape( ) , vec! [ ] ) ;
576
577
assert ! ( dt. base( ) . is_equiv_to( dt) ) ;
577
578
let x = dt. get_field ( "x" ) . unwrap ( ) ;
578
579
assert ! ( x. 0 . is_equiv_to( dtype:: <u8 >( py) ) ) ;
0 commit comments