Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ impl<C: Configuration> IngredientImpl<C> {
fields,
revisions,
durabilities,
memos: MemoTable::new(self.memo_table_types()),
// SAFETY: We only ever access the memos of a value that we allocated through
// our `MemoTableTypes`.
memos: unsafe { MemoTable::new(self.memo_table_types()) },
})
});

Expand Down
4 changes: 3 additions & 1 deletion src/interned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ where
let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value::<C> {
shard: shard_index as u16,
link: LinkedListLink::new(),
memos: UnsafeCell::new(MemoTable::new(self.memo_table_types())),
// SAFETY: We only ever access the memos of a value that we allocated through
// our `MemoTableTypes`.
memos: UnsafeCell::new(unsafe { MemoTable::new(self.memo_table_types()) }),
// SAFETY: We call `from_internal_data` to restore the correct lifetime before access.
fields: UnsafeCell::new(unsafe { self.to_internal_data(assemble(id, key)) }),
shared: UnsafeCell::new(ValueShared {
Expand Down
102 changes: 62 additions & 40 deletions src/table/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ pub(crate) struct MemoTable {

impl MemoTable {
/// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`.
pub fn new(types: &MemoTableTypes) -> Self {
///
/// # Safety
///
/// The created memo table must only be accessed with the same `MemoTableTypes`.
pub unsafe fn new(types: &MemoTableTypes) -> Self {
// Note that the safety invariant guarantees that any indices in-bounds for
// this table are also in-bounds for its `MemoTableTypes`, as `MemoTableTypes`
// is append-only.
Self {
memos: (0..types.len()).map(|_| MemoEntry::default()).collect(),
}
Expand Down Expand Up @@ -179,46 +186,51 @@ impl MemoTableWithTypes<'_> {
memo_ingredient_index: MemoIngredientIndex,
memo: NonNull<M>,
) -> Option<NonNull<M>> {
// The type must already exist, we insert it when creating the memo ingredient.
assert_eq!(
let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?;

// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
// corresponding `MemoTableTypes`, by construction.
let type_ = unsafe {
self.types
.types
.get(memo_ingredient_index.as_usize())?
.type_id,
TypeId::of::<M>(),
"inconsistent type-id for `{memo_ingredient_index:?}`"
);

// The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`.
let MemoEntry { atomic_memo } = self
.memos
.memos
.get(memo_ingredient_index.as_usize())
.expect("accessed memo table with invalid index");
.get_unchecked(memo_ingredient_index.as_usize())
};

let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel);
// Verify that the we are casting to the correct type.
if type_.type_id != TypeId::of::<M>() {
type_assert_failed(memo_ingredient_index);
}

let old_memo = NonNull::new(old_memo);
let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel);

// SAFETY: `type_id` check asserted above
old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) })
// SAFETY: We asserted that the type is correct above.
NonNull::new(old_memo).map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) })
}

/// Returns a pointer to the memo at the given index, if one has been inserted.
#[inline]
pub(crate) fn get<M: Memo>(
self,
memo_ingredient_index: MemoIngredientIndex,
) -> Option<NonNull<M>> {
let memo = self.memos.memos.get(memo_ingredient_index.as_usize())?;
let type_ = self.types.types.get(memo_ingredient_index.as_usize())?;
assert_eq!(
type_.type_id,
TypeId::of::<M>(),
"inconsistent type-id for `{memo_ingredient_index:?}`"
);
let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire))?;
// SAFETY: `type_id` check asserted above
Some(unsafe { MemoEntryType::from_dummy(memo) })
let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?;

// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
// corresponding `MemoTableTypes`, by construction.
let type_ = unsafe {
self.types
.types
.get_unchecked(memo_ingredient_index.as_usize())
};

// Verify that the we are casting to the correct type.
if type_.type_id != TypeId::of::<M>() {
type_assert_failed(memo_ingredient_index);
}

NonNull::new(atomic_memo.load(Ordering::Acquire))
// SAFETY: We asserted that the type is correct above.
.map(|memo| unsafe { MemoEntryType::from_dummy(memo) })
}

#[cfg(feature = "salsa_unstable")]
Expand Down Expand Up @@ -256,27 +268,30 @@ impl MemoTableWithTypesMut<'_> {
memo_ingredient_index: MemoIngredientIndex,
f: impl FnOnce(&mut M),
) {
let Some(type_) = self.types.types.get(memo_ingredient_index.as_usize()) else {
return;
};
assert_eq!(
type_.type_id,
TypeId::of::<M>(),
"inconsistent type-id for `{memo_ingredient_index:?}`"
);

// The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`.
let Some(MemoEntry { atomic_memo }) =
self.memos.memos.get_mut(memo_ingredient_index.as_usize())
else {
return;
};

// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
// corresponding `MemoTableTypes`, by construction.
let type_ = unsafe {
self.types
.types
.get_unchecked(memo_ingredient_index.as_usize())
};

// Verify that the we are casting to the correct type.
if type_.type_id != TypeId::of::<M>() {
type_assert_failed(memo_ingredient_index);
}

let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else {
return;
};

// SAFETY: `type_id` check asserted above
// SAFETY: We asserted that the type is correct above.
f(unsafe { MemoEntryType::from_dummy(memo).as_mut() });
}

Expand Down Expand Up @@ -319,6 +334,13 @@ impl MemoTableWithTypesMut<'_> {
}
}

/// This function is explicitly outlined to avoid debug machinery in the hot-path.
#[cold]
#[inline(never)]
fn type_assert_failed(memo_ingredient_index: MemoIngredientIndex) -> ! {
panic!("inconsistent type-id for `{memo_ingredient_index:?}`")
}

impl MemoEntry {
/// # Safety
///
Expand Down
4 changes: 3 additions & 1 deletion src/tracked_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ where
// lifetime erase for storage
fields: unsafe { mem::transmute::<C::Fields<'db>, C::Fields<'static>>(fields) },
revisions: C::new_revisions(current_deps.changed_at),
memos: MemoTable::new(self.memo_table_types()),
// SAFETY: We only ever access the memos of a value that we allocated through
// our `MemoTableTypes`.
memos: unsafe { MemoTable::new(self.memo_table_types()) },
};

while let Some(id) = self.free_list.pop() {
Expand Down