Skip to content

Commit a4a4016

Browse files
committed
PyArrayDescr::shape returns empty vec if scalar
1 parent ce897cb commit a4a4016

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/dtype.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,21 +190,22 @@ impl PyArrayDescr {
190190
}
191191
}
192192

193-
/// Returns shape tuple of the sub-array if this dtype is a sub-array, and `None` otherwise.
193+
/// Returns the shape of the sub-array.
194+
///
195+
/// If the dtype is not a sub-array, an empty vector is returned.
194196
///
195197
/// Equivalent to [`np.dtype.shape`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html).
196-
pub fn shape(&self) -> Option<Vec<usize>> {
198+
pub fn shape(&self) -> Vec<usize> {
197199
if !self.has_subarray() {
198-
return None;
199-
}
200-
Some(
200+
vec![]
201+
} else {
201202
// Panic-wise: numpy guarantees that shape is a tuple of non-negative integers
202203
unsafe {
203204
PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
204205
}
205206
.extract()
206-
.unwrap(),
207-
)
207+
.unwrap()
208+
}
208209
}
209210

210211
/// Returns true if the dtype is a sub-array at the top level.
@@ -501,7 +502,7 @@ mod tests {
501502
assert!(!dt.has_subarray());
502503
assert!(dt.base().is_equiv_to(dt));
503504
assert_eq!(dt.ndim(), 0);
504-
assert_eq!(dt.shape(), None);
505+
assert_eq!(dt.shape(), vec![]);
505506
});
506507
}
507508

@@ -535,7 +536,7 @@ mod tests {
535536
assert!(!dt.is_aligned_struct());
536537
assert!(dt.has_subarray());
537538
assert_eq!(dt.ndim(), 2);
538-
assert_eq!(dt.shape().unwrap(), vec![2, 3]);
539+
assert_eq!(dt.shape(), vec![2, 3]);
539540
assert!(dt.base().is_equiv_to(dtype::<f64>(py)));
540541
});
541542
}
@@ -572,7 +573,7 @@ mod tests {
572573
assert!(dt.is_aligned_struct());
573574
assert!(!dt.has_subarray());
574575
assert_eq!(dt.ndim(), 0);
575-
assert_eq!(dt.shape(), None);
576+
assert_eq!(dt.shape(), vec![]);
576577
assert!(dt.base().is_equiv_to(dt));
577578
let x = dt.get_field("x").unwrap();
578579
assert!(x.0.is_equiv_to(dtype::<u8>(py)));

0 commit comments

Comments
 (0)