Skip to content

Commit

Permalink
next_array: Revise safety comments, Drop, and push
Browse files Browse the repository at this point in the history
  • Loading branch information
jswrenn committed Jul 5, 2024
1 parent 2c3af5c commit 9911308
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 29 deletions.
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1928,9 +1928,9 @@ pub trait Itertools: Iterator {
///
/// assert_eq!(Some([1, 2]), iter.next_array());
/// ```
fn next_array<T, const N: usize>(&mut self) -> Option<[T; N]>
fn next_array<const N: usize>(&mut self) -> Option<[Self::Item; N]>
where
Self: Sized + Iterator<Item = T>,
Self: Sized,
{
next_array::next_array(self)
}
Expand All @@ -1952,9 +1952,9 @@ pub trait Itertools: Iterator {
/// panic!("Expected two elements")
/// }
/// ```
fn collect_array<T, const N: usize>(mut self) -> Option<[T; N]>
fn collect_array<const N: usize>(mut self) -> Option<[Self::Item; N]>
where
Self: Sized + Iterator<Item = T>,
Self: Sized,

Check warning on line 1957 in src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/lib.rs#L1957

Added line #L1957 was not covered by tests
{
self.next_array().filter(|_| self.next().is_none())
}
Expand Down
89 changes: 67 additions & 22 deletions src/next_array.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use core::mem::{self, MaybeUninit};
use core::ptr;

/// An array of at most `N` elements.
struct ArrayBuilder<T, const N: usize> {
Expand All @@ -17,7 +16,7 @@ struct ArrayBuilder<T, const N: usize> {
impl<T, const N: usize> ArrayBuilder<T, N> {
/// Initializes a new, empty `ArrayBuilder`.
pub fn new() -> Self {
// SAFETY: the validity invariant trivially hold for a zero-length array.
// SAFETY: The safety invariant of `arr` trivially holds for `len = 0`.
Self {
arr: [(); N].map(|_| MaybeUninit::uninit()),
len: 0,
Expand All @@ -28,50 +27,96 @@ impl<T, const N: usize> ArrayBuilder<T, N> {
///
/// # Panics
///
/// This panics if `self.len() >= N`.
/// This panics if `self.len >= N` or if `self.len == usize::MAX`.
pub fn push(&mut self, value: T) {
// SAFETY: we maintain the invariant here that arr[..len] is valid.
// Indexing with self.len also ensures self.len < N, and thus <= N after
// the increment.
// PANICS: This will panic if `self.len >= N`.
// SAFETY: The safety invariant of `self.arr` applies to elements at
// indices `0..self.len` — not to the element at `self.len`. Writing to
// the element at index `self.len` therefore does not violate the safety
// invariant of `self.arr`. Even if this line panics, we have not
// created any intermediate invalid state.
self.arr[self.len] = MaybeUninit::new(value);
self.len += 1;
// PANICS: This will panic if `self.len == usize::MAX`.
// SAFETY: By invariant on `self.arr`, all elements at indicies
// `0..self.len` are valid. Due to the above write, the element at
// `self.len` is now also valid. Consequently, all elements at indicies
// `0..(self.len + 1)` are valid, and `self.len` can be safely
// incremented without violating `self.arr`'s invariant. It is fine if
// this increment panics, as we have not created any intermediate
// invalid state.
self.len = match self.len.checked_add(1) {
Some(sum) => sum,
None => panic!("`self.len == usize::MAX`; cannot increment `len`"),

Check warning on line 49 in src/next_array.rs

View check run for this annotation

Codecov / codecov/patch

src/next_array.rs#L49

Added line #L49 was not covered by tests
};
}

/// Consumes the elements in the `ArrayBuilder` and returns them as an array `[T; N]`.
/// Consumes the elements in the `ArrayBuilder` and returns them as an array
/// `[T; N]`.
///
/// If `self.len() < N`, this returns `None`.
pub fn take(&mut self) -> Option<[T; N]> {
if self.len == N {
// Take the array, resetting our length back to zero.
// SAFETY: Decreasing the value of `self.len` cannot violate the
// safety invariant on `self.arr`.
self.len = 0;

// SAFETY: Since `self.len` is 0, `self.arr` may safely contain
// uninitialized elements.
let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit()));

// SAFETY: we had len == N, so all elements in arr are valid.
Some(unsafe { arr.map(|v| v.assume_init()) })
Some(arr.map(|v| {
// SAFETY: We know that all elements of `arr` are valid because
// we checked that `len == N`.
unsafe { v.assume_init() }
}))
} else {
None

Check warning on line 73 in src/next_array.rs

View check run for this annotation

Codecov / codecov/patch

src/next_array.rs#L73

Added line #L73 was not covered by tests
}
}
}

impl<T, const N: usize> AsMut<[T]> for ArrayBuilder<T, N> {
fn as_mut(&mut self) -> &mut [T] {
let valid = &mut self.arr[..self.len];
// SAFETY: By invariant on `self.arr`, the elements of `self.arr` at
// indices `0..self.len` are in a valid state. Since `valid` references
// only these elements, the safety precondition of
// `slice_assume_init_mut` is satisfied.
unsafe { slice_assume_init_mut(valid) }
}
}

impl<T, const N: usize> Drop for ArrayBuilder<T, N> {
// We provide a non-trivial `Drop` impl, because the trivial impl would be a
// no-op; `MaybeUninit<T>` has no innate awareness of its own validity, and
// so it can only forget its contents. By leveraging the safety invariant of
// `self.arr`, we do know which elements of `self.arr` are valid, and can
// selectively run their destructors.
fn drop(&mut self) {
unsafe {
// SAFETY: arr[..len] is valid, so must be dropped. First we create
// a pointer to this valid slice, then drop that slice in-place.
// The cast from *mut MaybeUninit<T> to *mut T is always sound by
// the layout guarantees of MaybeUninit.
let ptr_to_first: *mut MaybeUninit<T> = self.arr.as_mut_ptr();
let ptr_to_slice = ptr::slice_from_raw_parts_mut(ptr_to_first.cast::<T>(), self.len);
ptr::drop_in_place(ptr_to_slice);
}
let valid = self.as_mut();
// SAFETY: TODO
unsafe { core::ptr::drop_in_place(valid) }
}
}

/// Assuming all the elements are initialized, get a mutable slice to them.
///
/// # Safety
///
/// The caller guarantees that the elements `T` referenced by `slice` are in a
/// valid state.
unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
// SAFETY: Casting `&mut [MaybeUninit<T>]` to `&mut [T]` is sound, because
// `MaybeUninit<T>` is guaranteed to have the same size, alignment and ABI
// as `T`, and because the caller has guaranteed that `slice` is in the
// valid state.
unsafe { &mut *(slice as *mut [MaybeUninit<T>] as *mut [T]) }
}

/// Equivalent to `it.next_array()`.
pub fn next_array<I, T, const N: usize>(it: &mut I) -> Option<[T; N]>
pub fn next_array<I, const N: usize>(it: &mut I) -> Option<[I::Item; N]>
where
I: Iterator<Item = T>,
I: Iterator,
{
let mut builder = ArrayBuilder::new();
for _ in 0..N {
Expand Down
6 changes: 3 additions & 3 deletions tests/test_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ fn next_array() {
assert_eq!(iter.next_array(), Some([]));
assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2]));
assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4]));
assert_eq!(iter.next_array::<_, 2>(), None);
assert_eq!(iter.next_array::<2>(), None);
}

#[test]
Expand All @@ -391,9 +391,9 @@ fn collect_array() {

let v = [1];
let iter = v.iter().cloned();
assert_eq!(iter.collect_array::<_, 2>(), None);
assert_eq!(iter.collect_array::<2>(), None);

let v = [1, 2, 3];
let iter = v.iter().cloned();
assert_eq!(iter.collect_array::<_, 2>(), None);
assert_eq!(iter.collect_array::<2>(), None);
}

0 comments on commit 9911308

Please sign in to comment.