diff --git a/src/liballoc/tests/vec.rs b/src/liballoc/tests/vec.rs index 5ddac673c9ff1..c0967cd374d5e 100644 --- a/src/liballoc/tests/vec.rs +++ b/src/liballoc/tests/vec.rs @@ -945,6 +945,105 @@ fn drain_filter_complex() { } } +#[test] +fn drain_filter_consumed_panic() { + use std::rc::Rc; + use std::sync::Mutex; + + struct Check { + index: usize, + drop_counts: Rc>>, + }; + + impl Drop for Check { + fn drop(&mut self) { + self.drop_counts.lock().unwrap()[self.index] += 1; + println!("drop: {}", self.index); + } + } + + let check_count = 10; + let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count])); + let mut data: Vec = (0..check_count) + .map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) }) + .collect(); + + let _ = std::panic::catch_unwind(move || { + let filter = |c: &mut Check| { + if c.index == 2 { + panic!("panic at index: {}", c.index); + } + // Verify that if the filter could panic again on another element + // that it would not cause a double panic and all elements of the + // vec would still be dropped exactly once. + if c.index == 4 { + panic!("panic at index: {}", c.index); + } + c.index < 6 + }; + let drain = data.drain_filter(filter); + + // NOTE: The DrainFilter is explictly consumed + drain.for_each(drop); + }); + + let drop_counts = drop_counts.lock().unwrap(); + assert_eq!(check_count, drop_counts.len()); + + for (index, count) in drop_counts.iter().cloned().enumerate() { + assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count); + } +} + +#[test] +fn drain_filter_unconsumed_panic() { + use std::rc::Rc; + use std::sync::Mutex; + + struct Check { + index: usize, + drop_counts: Rc>>, + }; + + impl Drop for Check { + fn drop(&mut self) { + self.drop_counts.lock().unwrap()[self.index] += 1; + println!("drop: {}", self.index); + } + } + + let check_count = 10; + let drop_counts = Rc::new(Mutex::new(vec![0_usize; check_count])); + let mut data: Vec = (0..check_count) + .map(|index| Check { index, drop_counts: Rc::clone(&drop_counts) }) + .collect(); + + let _ = std::panic::catch_unwind(move || { + let filter = |c: &mut Check| { + if c.index == 2 { + panic!("panic at index: {}", c.index); + } + // Verify that if the filter could panic again on another element + // that it would not cause a double panic and all elements of the + // vec would still be dropped exactly once. + if c.index == 4 { + panic!("panic at index: {}", c.index); + } + c.index < 6 + }; + let _drain = data.drain_filter(filter); + + // NOTE: The DrainFilter is dropped without being consumed + }); + + let drop_counts = drop_counts.lock().unwrap(); + assert_eq!(check_count, drop_counts.len()); + + for (index, count) in drop_counts.iter().cloned().enumerate() { + assert_eq!(1, count, "unexpected drop count at index: {} (count: {})", index, count); + } +} + #[test] fn test_reserve_exact() { // This is all the same as test_reserve diff --git a/src/liballoc/vec.rs b/src/liballoc/vec.rs index 5cb91395b7bf7..adc0929fdbed4 100644 --- a/src/liballoc/vec.rs +++ b/src/liballoc/vec.rs @@ -2120,6 +2120,7 @@ impl Vec { del: 0, old_len, pred: filter, + panic_flag: false, } } } @@ -2751,6 +2752,7 @@ pub struct DrainFilter<'a, T, F> del: usize, old_len: usize, pred: F, + panic_flag: bool, } #[unstable(feature = "drain_filter", reason = "recently added", issue = "43244")] @@ -2760,21 +2762,34 @@ impl Iterator for DrainFilter<'_, T, F> type Item = T; fn next(&mut self) -> Option { + struct SetIdxOnDrop<'a> { + idx: &'a mut usize, + new_idx: usize, + } + + impl<'a> Drop for SetIdxOnDrop<'a> { + fn drop(&mut self) { + *self.idx = self.new_idx; + } + } + unsafe { - while self.idx != self.old_len { + while self.idx < self.old_len { let i = self.idx; - self.idx += 1; let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len); - if (self.pred)(&mut v[i]) { + let mut set_idx = SetIdxOnDrop { new_idx: self.idx, idx: &mut self.idx }; + self.panic_flag = true; + let drained = (self.pred)(&mut v[i]); + self.panic_flag = false; + set_idx.new_idx += 1; + if drained { self.del += 1; return Some(ptr::read(&v[i])); - } else if self.del > 0 { + } + else if self.del > 0 { let del = self.del; let src: *const T = &v[i]; let dst: *mut T = &mut v[i - del]; - // This is safe because self.vec has length 0 - // thus its elements will not have Drop::drop - // called on them in the event of a panic. ptr::copy_nonoverlapping(src, dst, 1); } } @@ -2792,9 +2807,47 @@ impl Drop for DrainFilter<'_, T, F> where F: FnMut(&mut T) -> bool, { fn drop(&mut self) { - self.for_each(drop); - unsafe { - self.vec.set_len(self.old_len - self.del); + // If the predicate panics, we still need to backshift everything + // down after the last successfully drained element, but no additional + // elements are drained or checked. + struct BackshiftOnDrop<'a, 'b, T, F> + where + F: FnMut(&mut T) -> bool, + { + drain: &'b mut DrainFilter<'a, T, F>, + } + + impl<'a, 'b, T, F> Drop for BackshiftOnDrop<'a, 'b, T, F> + where + F: FnMut(&mut T) -> bool + { + fn drop(&mut self) { + unsafe { + while self.drain.idx < self.drain.old_len { + let i = self.drain.idx; + self.drain.idx += 1; + let v = slice::from_raw_parts_mut( + self.drain.vec.as_mut_ptr(), + self.drain.old_len, + ); + if self.drain.del > 0 { + let del = self.drain.del; + let src: *const T = &v[i]; + let dst: *mut T = &mut v[i - del]; + ptr::copy_nonoverlapping(src, dst, 1); + } + } + self.drain.vec.set_len(self.drain.old_len - self.drain.del); + } + } + } + + let backshift = BackshiftOnDrop { + drain: self + }; + + if !backshift.drain.panic_flag { + backshift.drain.for_each(drop); } } }