Skip to content

Commit

Permalink
drop overwritten component data on double insert (#2227)
Browse files Browse the repository at this point in the history
Continuing the work on reducing the safety footguns in the code, I've removed one extra `UnsafeCell` in favour of safe `Cell` usage inisde `ComponentTicks`. That change led to discovery of misbehaving component insert logic, where data wasn't properly dropped when overwritten. Apart from that being fixed, some method names were changed to better convey the "initialize new allocation" and "replace existing allocation" semantic.

Depends on #2221, I will rebase this PR after the dependency is merged. For now, review just the last commit.

Co-authored-by: Carter Anderson <mcanders1@gmail.com>
  • Loading branch information
Frizi and cart committed May 30, 2021
1 parent 173bb48 commit 1214dda
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 151 deletions.
14 changes: 7 additions & 7 deletions crates/bevy_ecs/src/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,20 @@ impl BundleInfo {
// bundle_info.component_ids are also in "bundle order"
let mut bundle_component = 0;
bundle.get_components(|component_ptr| {
// SAFE: component_id was initialized by get_dynamic_bundle_info
let component_id = *self.component_ids.get_unchecked(bundle_component);
let component_status = bundle_status.get_unchecked(bundle_component);
match self.storage_types[bundle_component] {
StorageType::Table => {
let column = table.get_column_mut(component_id).unwrap();
column.set_data_unchecked(table_row, component_ptr);
let column_status = column.get_ticks_unchecked_mut(table_row);
match component_status {
match bundle_status.get_unchecked(bundle_component) {
ComponentStatus::Added => {
*column_status = ComponentTicks::new(change_tick);
column.initialize(
table_row,
component_ptr,
ComponentTicks::new(change_tick),
);
}
ComponentStatus::Mutated => {
column_status.set_changed(change_tick);
column.replace(table_row, component_ptr, change_tick);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_ecs/src/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl Components {
}
}

#[derive(Copy, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct ComponentTicks {
pub(crate) added: u32,
pub(crate) changed: u32,
Expand Down
52 changes: 51 additions & 1 deletion crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,34 @@ mod tests {
};
use bevy_tasks::TaskPool;
use parking_lot::Mutex;
use std::{any::TypeId, sync::Arc};
use std::{
any::TypeId,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};

#[derive(Debug, PartialEq, Eq)]
struct A(usize);
struct B(usize);
struct C;

#[derive(Clone, Debug)]
struct DropCk(Arc<AtomicUsize>);
impl DropCk {
fn new_pair() -> (Self, Arc<AtomicUsize>) {
let atomic = Arc::new(AtomicUsize::new(0));
(DropCk(atomic.clone()), atomic)
}
}

impl Drop for DropCk {
fn drop(&mut self) {
self.0.as_ref().fetch_add(1, Ordering::Relaxed);
}
}

#[test]
fn random_access() {
let mut world = World::new();
Expand Down Expand Up @@ -1176,4 +1197,33 @@ mod tests {
});
assert_eq!(*world.get_resource::<i32>().unwrap(), 1);
}

#[test]
fn insert_overwrite_drop() {
let (dropck1, dropped1) = DropCk::new_pair();
let (dropck2, dropped2) = DropCk::new_pair();
let mut world = World::default();
world.spawn().insert(dropck1).insert(dropck2);
assert_eq!(dropped1.load(Ordering::Relaxed), 1);
assert_eq!(dropped2.load(Ordering::Relaxed), 0);
drop(world);
assert_eq!(dropped1.load(Ordering::Relaxed), 1);
assert_eq!(dropped2.load(Ordering::Relaxed), 1);
}

#[test]
fn insert_overwrite_drop_sparse() {
let (dropck1, dropped1) = DropCk::new_pair();
let (dropck2, dropped2) = DropCk::new_pair();
let mut world = World::default();
world
.register_component(ComponentDescriptor::new::<DropCk>(StorageType::SparseSet))
.unwrap();
world.spawn().insert(dropck1).insert(dropck2);
assert_eq!(dropped1.load(Ordering::Relaxed), 1);
assert_eq!(dropped2.load(Ordering::Relaxed), 0);
drop(world);
assert_eq!(dropped1.load(Ordering::Relaxed), 1);
assert_eq!(dropped2.load(Ordering::Relaxed), 1);
}
}
32 changes: 16 additions & 16 deletions crates/bevy_ecs/src/query/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
};
use bevy_ecs_macros::all_tuples;
use std::{
cell::UnsafeCell,
marker::PhantomData,
ptr::{self, NonNull},
};
Expand Down Expand Up @@ -343,7 +344,7 @@ impl<'w, T: Component> Fetch<'w> for ReadFetch<T> {
let column = tables[archetype.table_id()]
.get_column(state.component_id)
.unwrap();
self.table_components = column.get_ptr().cast::<T>();
self.table_components = column.get_data_ptr().cast::<T>();
}
StorageType::SparseSet => self.entities = archetype.entities().as_ptr(),
}
Expand All @@ -354,7 +355,7 @@ impl<'w, T: Component> Fetch<'w> for ReadFetch<T> {
self.table_components = table
.get_column(state.component_id)
.unwrap()
.get_ptr()
.get_data_ptr()
.cast::<T>();
}

Expand Down Expand Up @@ -387,7 +388,7 @@ impl<T: Component> WorldQuery for &mut T {
pub struct WriteFetch<T> {
storage_type: StorageType,
table_components: NonNull<T>,
table_ticks: *mut ComponentTicks,
table_ticks: *const UnsafeCell<ComponentTicks>,
entities: *const Entity,
entity_table_rows: *const usize,
sparse_set: *const ComponentSparseSet,
Expand Down Expand Up @@ -482,7 +483,7 @@ impl<'w, T: Component> Fetch<'w> for WriteFetch<T> {
entities: ptr::null::<Entity>(),
entity_table_rows: ptr::null::<usize>(),
sparse_set: ptr::null::<ComponentSparseSet>(),
table_ticks: ptr::null_mut::<ComponentTicks>(),
table_ticks: ptr::null::<UnsafeCell<ComponentTicks>>(),
last_change_tick,
change_tick,
};
Expand All @@ -509,8 +510,8 @@ impl<'w, T: Component> Fetch<'w> for WriteFetch<T> {
let column = tables[archetype.table_id()]
.get_column(state.component_id)
.unwrap();
self.table_components = column.get_ptr().cast::<T>();
self.table_ticks = column.get_ticks_mut_ptr();
self.table_components = column.get_data_ptr().cast::<T>();
self.table_ticks = column.get_ticks_ptr();
}
StorageType::SparseSet => self.entities = archetype.entities().as_ptr(),
}
Expand All @@ -519,8 +520,8 @@ impl<'w, T: Component> Fetch<'w> for WriteFetch<T> {
#[inline]
unsafe fn set_table(&mut self, state: &Self::State, table: &Table) {
let column = table.get_column(state.component_id).unwrap();
self.table_components = column.get_ptr().cast::<T>();
self.table_ticks = column.get_ticks_mut_ptr();
self.table_components = column.get_data_ptr().cast::<T>();
self.table_ticks = column.get_ticks_ptr();
}

#[inline]
Expand All @@ -531,7 +532,7 @@ impl<'w, T: Component> Fetch<'w> for WriteFetch<T> {
Mut {
value: &mut *self.table_components.as_ptr().add(table_row),
ticks: Ticks {
component_ticks: &mut *self.table_ticks.add(table_row),
component_ticks: &mut *(&*self.table_ticks.add(table_row)).get(),
change_tick: self.change_tick,
last_change_tick: self.last_change_tick,
},
Expand All @@ -558,7 +559,7 @@ impl<'w, T: Component> Fetch<'w> for WriteFetch<T> {
Mut {
value: &mut *self.table_components.as_ptr().add(table_row),
ticks: Ticks {
component_ticks: &mut *self.table_ticks.add(table_row),
component_ticks: &mut *(&*self.table_ticks.add(table_row)).get(),
change_tick: self.change_tick,
last_change_tick: self.last_change_tick,
},
Expand Down Expand Up @@ -860,7 +861,7 @@ impl<'w, T: Component> Fetch<'w> for ChangeTrackersFetch<T> {
let column = tables[archetype.table_id()]
.get_column(state.component_id)
.unwrap();
self.table_ticks = column.get_ticks_mut_ptr().cast::<ComponentTicks>();
self.table_ticks = column.get_ticks_const_ptr();
}
StorageType::SparseSet => self.entities = archetype.entities().as_ptr(),
}
Expand All @@ -871,8 +872,7 @@ impl<'w, T: Component> Fetch<'w> for ChangeTrackersFetch<T> {
self.table_ticks = table
.get_column(state.component_id)
.unwrap()
.get_ticks_mut_ptr()
.cast::<ComponentTicks>();
.get_ticks_const_ptr();
}

#[inline]
Expand All @@ -881,7 +881,7 @@ impl<'w, T: Component> Fetch<'w> for ChangeTrackersFetch<T> {
StorageType::Table => {
let table_row = *self.entity_table_rows.add(archetype_index);
ChangeTrackers {
component_ticks: *self.table_ticks.add(table_row),
component_ticks: (&*self.table_ticks.add(table_row)).clone(),
marker: PhantomData,
last_change_tick: self.last_change_tick,
change_tick: self.change_tick,
Expand All @@ -890,7 +890,7 @@ impl<'w, T: Component> Fetch<'w> for ChangeTrackersFetch<T> {
StorageType::SparseSet => {
let entity = *self.entities.add(archetype_index);
ChangeTrackers {
component_ticks: *(*self.sparse_set).get_ticks(entity).unwrap(),
component_ticks: (&*self.sparse_set).get_ticks(entity).cloned().unwrap(),
marker: PhantomData,
last_change_tick: self.last_change_tick,
change_tick: self.change_tick,
Expand All @@ -902,7 +902,7 @@ impl<'w, T: Component> Fetch<'w> for ChangeTrackersFetch<T> {
#[inline]
unsafe fn table_fetch(&mut self, table_row: usize) -> Self::Item {
ChangeTrackers {
component_ticks: *self.table_ticks.add(table_row),
component_ticks: (&*self.table_ticks.add(table_row)).clone(),
marker: PhantomData,
last_change_tick: self.last_change_tick,
change_tick: self.change_tick,
Expand Down
16 changes: 8 additions & 8 deletions crates/bevy_ecs/src/query/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
world::World,
};
use bevy_ecs_macros::all_tuples;
use std::{marker::PhantomData, ptr};
use std::{cell::UnsafeCell, marker::PhantomData, ptr};

// TODO: uncomment this and use as shorthand (remove where F::Fetch: FilterFetch everywhere) when
// this bug is fixed in Rust 1.51: https://github.com/rust-lang/rust/pull/81671
Expand Down Expand Up @@ -561,7 +561,7 @@ macro_rules! impl_tick_filter {
$(#[$fetch_meta])*
pub struct $fetch_name<T> {
storage_type: StorageType,
table_ticks: *mut ComponentTicks,
table_ticks: *const UnsafeCell<ComponentTicks>,
entity_table_rows: *const usize,
marker: PhantomData<T>,
entities: *const Entity,
Expand Down Expand Up @@ -630,7 +630,7 @@ macro_rules! impl_tick_filter {
unsafe fn init(world: &World, state: &Self::State, last_change_tick: u32, change_tick: u32) -> Self {
let mut value = Self {
storage_type: state.storage_type,
table_ticks: ptr::null_mut::<ComponentTicks>(),
table_ticks: ptr::null::<UnsafeCell<ComponentTicks>>(),
entities: ptr::null::<Entity>(),
entity_table_rows: ptr::null::<usize>(),
sparse_set: ptr::null::<ComponentSparseSet>(),
Expand All @@ -655,7 +655,7 @@ macro_rules! impl_tick_filter {
unsafe fn set_table(&mut self, state: &Self::State, table: &Table) {
self.table_ticks = table
.get_column(state.component_id).unwrap()
.get_ticks_mut_ptr();
.get_ticks_ptr();
}

unsafe fn set_archetype(&mut self, state: &Self::State, archetype: &Archetype, tables: &Tables) {
Expand All @@ -665,25 +665,25 @@ macro_rules! impl_tick_filter {
let table = &tables[archetype.table_id()];
self.table_ticks = table
.get_column(state.component_id).unwrap()
.get_ticks_mut_ptr();
.get_ticks_ptr();
}
StorageType::SparseSet => self.entities = archetype.entities().as_ptr(),
}
}

unsafe fn table_fetch(&mut self, table_row: usize) -> bool {
$is_detected(&*self.table_ticks.add(table_row), self.last_change_tick, self.change_tick)
$is_detected(&*(&*self.table_ticks.add(table_row)).get(), self.last_change_tick, self.change_tick)
}

unsafe fn archetype_fetch(&mut self, archetype_index: usize) -> bool {
match self.storage_type {
StorageType::Table => {
let table_row = *self.entity_table_rows.add(archetype_index);
$is_detected(&*self.table_ticks.add(table_row), self.last_change_tick, self.change_tick)
$is_detected(&*(&*self.table_ticks.add(table_row)).get(), self.last_change_tick, self.change_tick)
}
StorageType::SparseSet => {
let entity = *self.entities.add(archetype_index);
let ticks = (*(*self.sparse_set).get_ticks(entity).unwrap());
let ticks = (&*self.sparse_set).get_ticks(entity).cloned().unwrap();
$is_detected(&ticks, self.last_change_tick, self.change_tick)
}
}
Expand Down
18 changes: 14 additions & 4 deletions crates/bevy_ecs/src/storage/blob_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,25 @@ impl BlobVec {
}

/// # Safety
/// `index` must be in bounds
/// Allows aliased mutable access to `index`'s data. Caller must ensure this does not happen
/// - index must be in bounds
/// - memory must be reserved and uninitialized
#[inline]
pub unsafe fn set_unchecked(&self, index: usize, value: *mut u8) {
pub unsafe fn initialize_unchecked(&mut self, index: usize, value: *mut u8) {
debug_assert!(index < self.len());
let ptr = self.get_unchecked(index);
std::ptr::copy_nonoverlapping(value, ptr, self.item_layout.size());
}

/// # Safety
/// - index must be in-bounds
// - memory must be previously initialized
pub unsafe fn replace_unchecked(&mut self, index: usize, value: *mut u8) {
debug_assert!(index < self.len());
let ptr = self.get_unchecked(index);
(self.drop)(ptr);
std::ptr::copy_nonoverlapping(value, ptr, self.item_layout.size());
}

/// increases the length by one (and grows the vec if needed) with uninitialized memory and
/// returns the index
///
Expand Down Expand Up @@ -267,7 +277,7 @@ mod tests {
/// `blob_vec` must have a layout that matches Layout::new::<T>()
unsafe fn push<T>(blob_vec: &mut BlobVec, mut value: T) {
let index = blob_vec.push_uninit();
blob_vec.set_unchecked(index, (&mut value as *mut T).cast::<u8>());
blob_vec.initialize_unchecked(index, (&mut value as *mut T).cast::<u8>());
std::mem::forget(value);
}

Expand Down
Loading

0 comments on commit 1214dda

Please sign in to comment.