Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safer implementation of RepeatN #130887

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 45 additions & 86 deletions library/core/src/iter/sources/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::fmt;
use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
use crate::mem::{self, MaybeUninit};
use crate::num::NonZero;

/// Creates a new iterator that repeats a single element a given number of times.
Expand Down Expand Up @@ -57,78 +56,49 @@ use crate::num::NonZero;
#[inline]
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
let element = if count == 0 {
// `element` gets dropped eagerly.
MaybeUninit::uninit()
} else {
MaybeUninit::new(element)
};

RepeatN { element, count }
RepeatN { inner: RepeatNInner::new(element, count) }
}

#[derive(Clone, Copy)]
#[repr(C)] // keep the layout consistent for codegen tests
struct RepeatNInner<T> {
count: NonZero<usize>,
element: T,
}

impl<T> RepeatNInner<T> {
fn new(element: T, count: usize) -> Option<Self> {
let count = NonZero::<usize>::new(count)?;
Some(Self { element, count })
}
}

/// An iterator that repeats an element an exact number of times.
///
/// This `struct` is created by the [`repeat_n()`] function.
/// See its documentation for more.
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
#[derive(Clone)]
pub struct RepeatN<A> {
count: usize,
// Invariant: uninit iff count == 0.
element: MaybeUninit<A>,
inner: Option<RepeatNInner<A>>,
}

impl<A> RepeatN<A> {
/// Returns the element if it hasn't been dropped already.
fn element_ref(&self) -> Option<&A> {
if self.count > 0 {
// SAFETY: The count is non-zero, so it must be initialized.
Some(unsafe { self.element.assume_init_ref() })
} else {
None
}
}
/// If we haven't already dropped the element, return it in an option.
///
/// Clears the count so it won't be dropped again later.
#[inline]
fn take_element(&mut self) -> Option<A> {
if self.count > 0 {
self.count = 0;
let element = mem::replace(&mut self.element, MaybeUninit::uninit());
// SAFETY: We just set count to zero so it won't be dropped again,
// and it used to be non-zero so it hasn't already been dropped.
unsafe { Some(element.assume_init()) }
} else {
None
}
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> Clone for RepeatN<A> {
fn clone(&self) -> RepeatN<A> {
RepeatN {
count: self.count,
element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
}
self.inner.take().map(|inner| inner.element)
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RepeatN")
.field("count", &self.count)
.field("element", &self.element_ref())
.finish()
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A> Drop for RepeatN<A> {
fn drop(&mut self) {
self.take_element();
let (count, element) = match self.inner.as_ref() {
Some(inner) => (inner.count.get(), Some(&inner.element)),
None => (0, None),
};
f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
}
}

Expand All @@ -138,12 +108,17 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn next(&mut self) -> Option<A> {
if self.count > 0 {
// SAFETY: Just checked it's not empty
unsafe { Some(self.next_unchecked()) }
} else {
None
let inner = self.inner.as_mut()?;
let count = inner.count.get();

if let Some(decremented) = NonZero::<usize>::new(count - 1) {
// Order of these is important for optimization
let tmp = inner.element.clone();
inner.count = decremented;
return Some(tmp);
}

return self.take_element();
}

#[inline]
Expand All @@ -154,19 +129,19 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
let len = self.count;
let Some(inner) = self.inner.as_mut() else {
return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
};

if skip >= len {
self.take_element();
}
let len = inner.count.get();

if skip > len {
// SAFETY: we just checked that the difference is positive
Err(unsafe { NonZero::new_unchecked(skip - len) })
} else {
self.count = len - skip;
Ok(())
if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
inner.count = new_len;
return Ok(());
}

self.inner = None;
return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
}

#[inline]
Expand All @@ -183,7 +158,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> ExactSizeIterator for RepeatN<A> {
fn len(&self) -> usize {
self.count
self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
}
}

Expand Down Expand Up @@ -211,20 +186,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
#[unstable(feature = "trusted_len", issue = "37572")]
unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> UncheckedIterator for RepeatN<A> {
#[inline]
unsafe fn next_unchecked(&mut self) -> Self::Item {
// SAFETY: The caller promised the iterator isn't empty
self.count = unsafe { self.count.unchecked_sub(1) };
if self.count == 0 {
// SAFETY: the check above ensured that the count used to be non-zero,
// so element hasn't been dropped yet, and we just lowered the count to
// zero so it won't be dropped later, and thus it's okay to take it here.
unsafe { mem::replace(&mut self.element, MaybeUninit::uninit()).assume_init() }
} else {
// SAFETY: the count is non-zero, so it must have not been dropped yet.
let element = unsafe { self.element.assume_init_ref() };
A::clone(element)
}
}
}
impl<A: Clone> UncheckedIterator for RepeatN<A> {}
2 changes: 1 addition & 1 deletion tests/codegen/iter-repeat-n-trivial-drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ pub fn iter_repeat_n_next(it: &mut std::iter::RepeatN<NotCopy>) -> Option<NotCop

// CHECK: [[NOT_EMPTY]]:
// CHECK-NEXT: %[[DEC:.+]] = add i64 %[[COUNT]], -1
// CHECK-NEXT: store i64 %[[DEC]]
// CHECK-NOT: br
// CHECK: %[[VAL:.+]] = load i16
// CHECK-NEXT: store i64 %[[DEC]]
// CHECK-NEXT: br label %[[EMPTY]]

// CHECK: [[EMPTY]]:
Expand Down
Loading