Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

collections: Make BinaryHeap panic safe in sift_up / sift_down #25856

Merged
merged 1 commit into from
May 28, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 88 additions & 25 deletions src/libcollections/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
use core::prelude::*;

use core::iter::{FromIterator};
use core::mem::{zeroed, replace, swap};
use core::mem::swap;
use core::ptr;

use slice;
Expand Down Expand Up @@ -484,46 +484,42 @@ impl<T: Ord> BinaryHeap<T> {

// The implementations of sift_up and sift_down use unsafe blocks in
// order to move an element out of the vector (leaving behind a
// zeroed element), shift along the others and move it back into the
// vector over the junk element. This reduces the constant factor
// compared to using swaps, which involves twice as many moves.
fn sift_up(&mut self, start: usize, mut pos: usize) {
// hole), shift along the others and move the removed element back into the
// vector at the final location of the hole.
// The `Hole` type is used to represent this, and make sure
// the hole is filled back at the end of its scope, even on panic.
// Using a hole reduces the constant factor compared to using swaps,
// which involves twice as many moves.
fn sift_up(&mut self, start: usize, pos: usize) {
unsafe {
let new = replace(&mut self.data[pos], zeroed());
// Take out the value at `pos` and create a hole.
let mut hole = Hole::new(&mut self.data, pos);

while pos > start {
let parent = (pos - 1) >> 1;

if new <= self.data[parent] { break; }

let x = replace(&mut self.data[parent], zeroed());
ptr::write(&mut self.data[pos], x);
pos = parent;
while hole.pos() > start {
let parent = (hole.pos() - 1) / 2;
if hole.removed() <= hole.get(parent) { break }
hole.move_to(parent);
}
ptr::write(&mut self.data[pos], new);
}
}

fn sift_down_range(&mut self, mut pos: usize, end: usize) {
let start = pos;
unsafe {
let start = pos;
let new = replace(&mut self.data[pos], zeroed());

let mut hole = Hole::new(&mut self.data, pos);
let mut child = 2 * pos + 1;
while child < end {
let right = child + 1;
if right < end && !(self.data[child] > self.data[right]) {
if right < end && !(hole.get(child) > hole.get(right)) {
child = right;
}
let x = replace(&mut self.data[child], zeroed());
ptr::write(&mut self.data[pos], x);
pos = child;
child = 2 * pos + 1;
hole.move_to(child);
child = 2 * hole.pos() + 1;
}

ptr::write(&mut self.data[pos], new);
self.sift_up(start, pos);
pos = hole.pos;
}
self.sift_up(start, pos);
}

fn sift_down(&mut self, pos: usize) {
Expand Down Expand Up @@ -554,6 +550,73 @@ impl<T: Ord> BinaryHeap<T> {
pub fn clear(&mut self) { self.drain(); }
}

/// `Hole` represents a hole in a slice, i.e. an index without a valid value
/// (because its value was moved out or duplicated).
/// In drop, `Hole` will restore the slice by filling the hole
/// position with the value that was originally removed.
struct Hole<'a, T: 'a> {
data: &'a mut [T],
/// `elt` is always `Some` from new until drop.
elt: Option<T>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps worth noting that this value is Some until this value is Dropped

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

pos: usize,
}

impl<'a, T> Hole<'a, T> {
    /// Opens a hole at index `pos` of `data`.
    ///
    /// The element at `pos` is copied out with `ptr::read` and kept in
    /// `elt`; the slot itself is not modified here.
    ///
    /// Panics if `pos` is out of bounds.
    fn new(data: &'a mut [T], pos: usize) -> Self {
        // The indexing expression is bounds-checked before the read happens.
        let taken = unsafe { ptr::read(&data[pos]) };
        Hole {
            data: data,
            elt: Some(taken),
            pos: pos,
        }
    }

    /// Current index of the hole.
    #[inline(always)]
    fn pos(&self) -> usize {
        self.pos
    }

    /// Borrows the element that was taken out when the hole was created.
    #[inline(always)]
    fn removed(&self) -> &T {
        let kept: Option<&T> = self.elt.as_ref();
        kept.unwrap()
    }

    /// Borrows the element stored at `index`.
    ///
    /// Panics if `index` is out of bounds.
    ///
    /// Unsafe because the caller must guarantee that `index` differs from
    /// the hole position (only checked in debug builds).
    #[inline(always)]
    unsafe fn get(&self, index: usize) -> &T {
        debug_assert!(index != self.pos);
        &self.data[index]
    }

    /// Relocates the hole to `index`.
    ///
    /// The element at `index` is copied into the current hole slot, after
    /// which `index` becomes the new hole position.
    ///
    /// Unsafe because the caller must guarantee that `index` differs from
    /// the hole position (only checked in debug builds).
    #[inline(always)]
    unsafe fn move_to(&mut self, index: usize) {
        debug_assert!(index != self.pos);
        // Both slots are obtained through bounds-checked indexing; the
        // source is looked up first, then the destination.
        let src: *const T = &self.data[index];
        let dst: *mut T = &mut self.data[self.pos];
        ptr::copy_nonoverlapping(src, dst, 1);
        self.pos = index;
    }
}

impl<'a, T> Drop for Hole<'a, T> {
    /// Fills the hole again: the element held in `elt` is written into the
    /// slot at the current hole position.
    fn drop(&mut self) {
        // Look up the destination slot first (bounds-checked), then take
        // the stored element out of `elt`, leaving `None` behind.
        let slot: *mut T = &mut self.data[self.pos];
        let value = self.elt.take().unwrap();
        // `ptr::write` stores `value` without reading or dropping whatever
        // currently occupies the slot.
        unsafe {
            ptr::write(slot, value);
        }
    }
}

/// `BinaryHeap` iterator.
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Iter <'a, T: 'a> {
Expand Down
108 changes: 108 additions & 0 deletions src/test/run-pass/binary-heap-panic-safe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

#![feature(std_misc, collections, catch_panic, rand)]

use std::__rand::{thread_rng, Rng};
use std::thread;

use std::collections::BinaryHeap;
use std::cmp;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT, Ordering};

static DROP_COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;

// Regression test: the old `BinaryHeap` implementation failed this test.
//
// Integrity means that all elements are present after a comparison panics,
// even if the order may not be correct.
//
// Destructors must be called exactly once per element.
//
// For every value `i` in `1..DATASZ + 1` (repeated `NTEST` times), a heap is
// built from all the other values, an element carrying `i` whose comparisons
// panic is pushed, and afterwards the global drop count and the set of
// elements found in the heap are checked.
fn test_integrity() {
// Element type used in the heap: field `.0` is the payload, field `.1`
// marks the element as one whose comparison panics (see `partial_cmp`).
#[derive(Eq, PartialEq, Ord, Clone, Debug)]
struct PanicOrd<T>(T, bool);

impl<T> Drop for PanicOrd<T> {
fn drop(&mut self) {
// update global drop count
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
}
}

impl<T: PartialOrd> PartialOrd for PanicOrd<T> {
// Panics when either operand has its flag set; otherwise compares the
// payloads only.
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
if self.1 || other.1 {
panic!("Panicking comparison");
}
self.0.partial_cmp(&other.0)
}
}
let mut rng = thread_rng();
// DATASZ: number of distinct payload values; NTEST: number of repetitions
// of the whole experiment.
const DATASZ: usize = 32;
const NTEST: usize = 10;

// don't use 0 in the data -- we want to catch the zeroed-out case.
let data = (1..DATASZ + 1).collect::<Vec<_>>();

// since it's a fuzzy test, run several tries.
for _ in 0..NTEST {
for i in 1..DATASZ + 1 {
// Start every trial with a fresh drop count.
DROP_COUNTER.store(0, Ordering::SeqCst);

// All values except `i`, wrapped as non-panicking elements; `i` itself
// becomes the single panicking element.
let mut panic_ords: Vec<_> = data.iter()
.filter(|&&x| x != i)
.map(|&x| PanicOrd(x, false))
.collect();
let panic_item = PanicOrd(i, true);

// heapify the sane items
rng.shuffle(&mut panic_ords);
let heap = Arc::new(Mutex::new(BinaryHeap::from_vec(panic_ords)));
// Assigned inside the nested scope below, read after it ends.
let inner_data;

{
// Second handle to the same heap; `heap` itself is moved into the
// closure below.
let heap_ref = heap.clone();


// push the panicking item to the heap and catch the panic
let thread_result = thread::catch_panic(move || {
heap.lock().unwrap().push(panic_item);
});
assert!(thread_result.is_err());

// Assert no elements were dropped
// NOTE(review): the assertion below is commented out, so this `drops`
// binding is unused and the "no drops yet" property is not enforced.
let drops = DROP_COUNTER.load(Ordering::SeqCst);
//assert!(drops == 0, "Must not drop items. drops={}", drops);

{
// now fetch the binary heap's data vector
// A lock error is not treated as fatal: the guard is recovered from the
// error value (presumably the mutex is poisoned by the panic above).
let mutex_guard = match heap_ref.lock() {
Ok(x) => x,
Err(poison) => poison.into_inner(),
};
// Snapshot: clone the heap and turn the clone into a plain vector.
inner_data = mutex_guard.clone().into_vec();
}
}
// Exactly DATASZ destructor calls are expected at this point.
// NOTE(review): presumably these come from the original heap being
// destroyed once `heap_ref` goes out of scope above, while the cloned
// elements in `inner_data` are still alive -- confirm.
let drops = DROP_COUNTER.load(Ordering::SeqCst);
assert_eq!(drops, DATASZ);

// The payloads found in the heap, once sorted, must equal `data`: every
// value from 1 to DATASZ is present exactly once.
let mut data_sorted = inner_data.into_iter().map(|p| p.0).collect::<Vec<_>>();
data_sorted.sort();
assert_eq!(data_sorted, data);
}
}
}

// Entry point of this run-pass test: runs the single integrity check.
fn main() {
    test_integrity()
}