Skip to content

fix Zip unsoundness (again) #141076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 33 additions & 41 deletions library/core/src/iter/adapters/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ pub struct Zip<A, B> {
// index, len and a_len are only used by the specialized version of zip
index: usize,
len: usize,
a_len: usize,
}
impl<A: Iterator, B: Iterator> Zip<A, B> {
pub(in crate::iter) fn new(a: A, b: B) -> Zip<A, B> {
Expand Down Expand Up @@ -158,7 +157,6 @@ macro_rules! zip_impl_general_defaults {
b,
index: 0, // unused
len: 0, // unused
a_len: 0, // unused
}
}

Expand Down Expand Up @@ -299,9 +297,8 @@ where
B: TrustedRandomAccess + Iterator,
{
fn new(a: A, b: B) -> Self {
let a_len = a.size();
let len = cmp::min(a_len, b.size());
Zip { a, b, index: 0, len, a_len }
let len = cmp::min(a.size(), b.size());
Zip { a, b, index: 0, len }
}

#[inline]
Expand All @@ -315,17 +312,6 @@ where
unsafe {
Some((self.a.__iterator_get_unchecked(i), self.b.__iterator_get_unchecked(i)))
}
} else if A::MAY_HAVE_SIDE_EFFECT && self.index < self.a_len {
let i = self.index;
// as above, increment before executing code that may panic
self.index += 1;
self.len += 1;
// match the base implementation's potential side effects
// SAFETY: we just checked that `i` < `self.a.len()`
unsafe {
self.a.__iterator_get_unchecked(i);
}
None
} else {
None
}
Expand Down Expand Up @@ -371,36 +357,42 @@ where
A: DoubleEndedIterator + ExactSizeIterator,
B: DoubleEndedIterator + ExactSizeIterator,
{
if A::MAY_HAVE_SIDE_EFFECT || B::MAY_HAVE_SIDE_EFFECT {
let sz_a = self.a.size();
let sz_b = self.b.size();
// Adjust a, b to equal length, make sure that only the first call
// of `next_back` does this, otherwise we will break the restriction
// on calls to `self.next_back()` after calling `get_unchecked()`.
if sz_a != sz_b {
// No effects when the iterator is exhausted, to reduce the number of
// cases the unsafe code has to handle.
// See #137255 for a case where where too many epicycles lead to unsoundness.
if self.index < self.len {
let old_len = self.len;

// since get_unchecked and the side-effecting code can execute user code
// which can panic we decrement the counter beforehand
// so that the same index won't be accessed twice, as required by TrustedRandomAccess.
// Additionally this will ensure that the side-effects code won't run a second time.
self.len -= 1;

// Adjust a, b to equal length if we're iterating backwards.
if A::MAY_HAVE_SIDE_EFFECT || B::MAY_HAVE_SIDE_EFFECT {
// note if some forward-iteration already happened then these aren't the real
// remaining lengths of the inner iterators, so we have to relate them to
// Zip's internal length-tracking.
let sz_a = self.a.size();
if A::MAY_HAVE_SIDE_EFFECT && sz_a > self.len {
for _ in 0..sz_a - self.len {
// since next_back() may panic we increment the counters beforehand
// to keep Zip's state in sync with the underlying iterator source
self.a_len -= 1;
self.a.next_back();
}
debug_assert_eq!(self.a_len, self.len);
}
let sz_b = self.b.size();
if B::MAY_HAVE_SIDE_EFFECT && sz_b > self.len {
for _ in 0..sz_b - self.len {
self.b.next_back();
// This condition can and must only be true on the first `next_back` call,
// otherwise we will break the restriction on calls to `self.next_back()`
// after calling `get_unchecked()`.
if sz_a != sz_b && (old_len == sz_a || old_len == sz_b) {
if A::MAY_HAVE_SIDE_EFFECT && sz_a > old_len {
for _ in 0..sz_a - old_len {
self.a.next_back();
}
}
if B::MAY_HAVE_SIDE_EFFECT && sz_b > old_len {
for _ in 0..sz_b - old_len {
self.b.next_back();
}
}
debug_assert_eq!(self.a.size(), self.b.size());
}
}
}
if self.index < self.len {
// since get_unchecked executes code which can panic we increment the counters beforehand
// so that the same index won't be accessed twice, as required by TrustedRandomAccess
self.len -= 1;
self.a_len -= 1;
let i = self.len;
// SAFETY: `i` is smaller than the previous value of `self.len`,
// which is also smaller than or equal to `self.a.len()` and `self.b.len()`
Expand Down
3 changes: 2 additions & 1 deletion library/coretests/tests/iter/adapters/cloned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ fn test_cloned_side_effects() {
.zip(&[1]);
for _ in iter {}
}
assert_eq!(count, 2);
// Zip documentation provides some leeway about side-effects
assert!([1, 2].iter().any(|v| *v == count));
}

#[test]
Expand Down
119 changes: 86 additions & 33 deletions library/coretests/tests/iter/adapters/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,19 @@ fn test_zip_next_back_side_effects_exhausted() {
iter.next();
iter.next();
iter.next();
iter.next();
assert_eq!(iter.next(), None);
assert_eq!(iter.next_back(), None);
assert_eq!(a, vec![1, 2, 3, 4, 6, 5]);

assert!(a.starts_with(&[1, 2, 3]));
let a_len = a.len();
// Tail-side-effects of forward-iteration are "at most one" per next().
// And for reverse iteration we don't guarantee much either.
// But we can put some bounds on the possible behaviors.
assert!(a_len <= 6);
assert!(a_len >= 3);
a.sort();
assert_eq!(a, &[1, 2, 3, 4, 5, 6][..a.len()]);

assert_eq!(b, vec![200, 300, 400]);
}

Expand All @@ -120,7 +130,8 @@ fn test_zip_cloned_sideffectful() {

for _ in xs.iter().cloned().zip(ys.iter().cloned()) {}

assert_eq!(&xs, &[1, 1, 1, 0][..]);
// Zip documentation permits either case.
assert!([&[1, 1, 1, 0], &[1, 1, 0, 0]].iter().any(|v| &xs == *v));
assert_eq!(&ys, &[1, 1][..]);

let xs = [CountClone::new(), CountClone::new()];
Expand All @@ -139,7 +150,8 @@ fn test_zip_map_sideffectful() {

for _ in xs.iter_mut().map(|x| *x += 1).zip(ys.iter_mut().map(|y| *y += 1)) {}

assert_eq!(&xs, &[1, 1, 1, 1, 1, 0]);
// Zip documentation permits either case.
assert!([&[1, 1, 1, 1, 1, 0], &[1, 1, 1, 1, 0, 0]].iter().any(|v| &xs == *v));
assert_eq!(&ys, &[1, 1, 1, 1]);

let mut xs = [0; 4];
Expand Down Expand Up @@ -168,7 +180,8 @@ fn test_zip_map_rev_sideffectful() {

{
let mut it = xs.iter_mut().map(|x| *x += 1).zip(ys.iter_mut().map(|y| *y += 1));
(&mut it).take(5).count();
// the current impl only trims the tails if the iterator isn't exhausted
(&mut it).take(3).count();
it.next_back();
}
assert_eq!(&xs, &[1, 1, 1, 1, 1, 1]);
Expand Down Expand Up @@ -211,9 +224,18 @@ fn test_zip_nth_back_side_effects_exhausted() {
iter.next();
iter.next();
iter.next();
iter.next();
assert_eq!(iter.next(), None);
assert_eq!(iter.nth_back(0), None);
assert_eq!(a, vec![1, 2, 3, 4, 6, 5]);
assert!(a.starts_with(&[1, 2, 3]));
let a_len = a.len();
// Tail-side-effects of forward-iteration are "at most one" per next().
// And for reverse iteration we don't guarantee much either.
// But we can put some bounds on the possible behaviors.
assert!(a_len <= 6);
assert!(a_len >= 3);
a.sort();
assert_eq!(a, &[1, 2, 3, 4, 5, 6][..a.len()]);

assert_eq!(b, vec![200, 300, 400]);
}

Expand All @@ -237,32 +259,6 @@ fn test_zip_trusted_random_access_composition() {
assert_eq!(z2.next().unwrap(), ((1, 1), 1));
}

#[test]
#[cfg(panic = "unwind")]
fn test_zip_trusted_random_access_next_back_drop() {
use std::panic::{AssertUnwindSafe, catch_unwind};

let mut counter = 0;

let it = [42].iter().map(|e| {
let c = counter;
counter += 1;
if c == 0 {
panic!("bomb");
}

e
});
let it2 = [(); 0].iter();
let mut zip = it.zip(it2);
catch_unwind(AssertUnwindSafe(|| {
zip.next_back();
}))
.unwrap_err();
assert!(zip.next().is_none());
assert_eq!(counter, 1);
}

#[test]
fn test_double_ended_zip() {
let xs = [1, 2, 3, 4, 5, 6];
Expand All @@ -275,6 +271,63 @@ fn test_double_ended_zip() {
assert_eq!(it.next(), None);
}

#[test]
#[cfg(panic = "unwind")]
/// Regresion test for #137255
/// A previous implementation of Zip TrustedRandomAccess specializations tried to do a lot of work
/// to preserve side-effects of equalizing the iterator lengths during backwards iteration.
/// This lead to several cases of unsoundness, twice due to being left in an inconsistent state
/// after panics.
/// The new implementation does not try as hard, but we still need panic-safety.
fn test_nested_zip_panic_safety() {
use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind};
use std::sync::atomic::{AtomicUsize, Ordering};

let mut panic = true;
// keeps track of how often element get visited, must be at most once each
let witness = [8, 9, 10, 11, 12].map(|i| (i, AtomicUsize::new(0)));
let a = witness.as_slice().iter().map(|e| {
e.1.fetch_add(1, Ordering::Relaxed);
if panic {
panic = false;
resume_unwind(Box::new(()))
}
e.0
});
// shorter than `a`, so `a` will get trimmed
let b = [1, 2, 3, 4].as_slice().iter().copied();
// shorter still, so `ab` will get trimmed.`
let c = [5, 6, 7].as_slice().iter().copied();

// This will panic during backwards trimming.
let ab = zip(a, b);
// This being Zip + TrustedRandomAccess means it will only call `next_back``
// during trimming and otherwise do calls `__iterator_get_unchecked` on `ab`.
let mut abc = zip(ab, c);

assert_eq!(abc.len(), 3);
// This will first trigger backwards trimming before it would normally obtain the
// actual element if it weren't for the panic.
// This used to corrupt the internal state of `abc`, which then lead to
// TrustedRandomAccess safety contract violations in calls to `ab`,
// which ultimately lead to UB.
catch_unwind(AssertUnwindSafe(|| abc.next_back())).ok();
// check for sane outward behavior after the panic, which indicates a sane internal state.
// Technically these outcomes are not required because a panic frees us from correctness obligations.
assert_eq!(abc.len(), 2);
assert_eq!(abc.next(), Some(((8, 1), 5)));
assert_eq!(abc.next_back(), Some(((9, 2), 6)));
for (i, (_, w)) in witness.iter().enumerate() {
let v = w.load(Ordering::Relaxed);
// required by TRA contract
assert!(v <= 1, "expected idx {i} to be visited at most once, actual: {v}");
}
// Trimming panicked and should only run once, so this one won't be visited.
// Implementation detail, but not trying to run it again is what keeps
// things simple.
assert_eq!(witness[3].1.load(Ordering::Relaxed), 0);
}

#[test]
fn test_issue_82282() {
fn overflowed_zip(arr: &[i32]) -> impl Iterator<Item = (i32, &())> {
Expand Down
Loading