Skip to content

Commit 4ee2152

Browse files
committed
remove extra bounds checks from memo table hot-paths
1 parent dba66f1 commit 4ee2152

File tree

4 files changed

+70
-43
lines changed

4 files changed

+70
-43
lines changed

src/input.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ impl<C: Configuration> IngredientImpl<C> {
116116
fields,
117117
revisions,
118118
durabilities,
119-
memos: MemoTable::new(self.memo_table_types()),
119+
// SAFETY: We only ever access the memos of a value that we allocated through
120+
// our `MemoTableTypes`.
121+
memos: unsafe { MemoTable::new(self.memo_table_types()) },
120122
})
121123
});
122124

src/interned.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,9 @@ where
586586
let id = zalsa_local.allocate(zalsa, self.ingredient_index, |id| Value::<C> {
587587
shard: shard_index as u16,
588588
link: LinkedListLink::new(),
589-
memos: UnsafeCell::new(MemoTable::new(self.memo_table_types())),
589+
// SAFETY: We only ever access the memos of a value that we allocated through
590+
// our `MemoTableTypes`.
591+
memos: UnsafeCell::new(unsafe { MemoTable::new(self.memo_table_types()) }),
590592
// SAFETY: We call `from_internal_data` to restore the correct lifetime before access.
591593
fields: UnsafeCell::new(unsafe { self.to_internal_data(assemble(id, key)) }),
592594
shared: UnsafeCell::new(ValueShared {

src/table/memo.rs

Lines changed: 61 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ pub(crate) struct MemoTable {
1515

1616
impl MemoTable {
1717
/// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`.
18-
pub fn new(types: &MemoTableTypes) -> Self {
18+
///
19+
/// # Safety
20+
///
21+
/// The created memo table must only be accessed with the same `MemoTableTypes`.
22+
pub unsafe fn new(types: &MemoTableTypes) -> Self {
23+
// Note that the safety invariant guarantees that any indices in-bounds for
24+
// this table are also in-bounds for its `MemoTableTypes`, as `MemoTableTypes`
25+
// is append-only.
1926
Self {
2027
memos: (0..types.len()).map(|_| MemoEntry::default()).collect(),
2128
}
@@ -179,46 +186,51 @@ impl MemoTableWithTypes<'_> {
179186
memo_ingredient_index: MemoIngredientIndex,
180187
memo: NonNull<M>,
181188
) -> Option<NonNull<M>> {
182-
// The type must already exist, we insert it when creating the memo ingredient.
183-
assert_eq!(
189+
let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?;
190+
191+
// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
192+
// corresponding `MemoTableTypes`, by construction.
193+
let type_ = unsafe {
184194
self.types
185195
.types
186-
.get(memo_ingredient_index.as_usize())?
187-
.type_id,
188-
TypeId::of::<M>(),
189-
"inconsistent type-id for `{memo_ingredient_index:?}`"
190-
);
191-
192-
// The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`.
193-
let MemoEntry { atomic_memo } = self
194-
.memos
195-
.memos
196-
.get(memo_ingredient_index.as_usize())
197-
.expect("accessed memo table with invalid index");
196+
.get_unchecked(memo_ingredient_index.as_usize())
197+
};
198198

199-
let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel);
199+
// Verify that the we are casting to the correct type.
200+
if type_.type_id != TypeId::of::<M>() {
201+
type_assert_failed(memo_ingredient_index);
202+
}
200203

201-
let old_memo = NonNull::new(old_memo);
204+
let old_memo = atomic_memo.swap(MemoEntryType::to_dummy(memo).as_ptr(), Ordering::AcqRel);
202205

203-
// SAFETY: `type_id` check asserted above
204-
old_memo.map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) })
206+
// SAFETY: We asserted that the type is correct above.
207+
NonNull::new(old_memo).map(|old_memo| unsafe { MemoEntryType::from_dummy(old_memo) })
205208
}
206209

210+
/// Returns a pointer to the memo at the given index, if one has been inserted.
207211
#[inline]
208212
pub(crate) fn get<M: Memo>(
209213
self,
210214
memo_ingredient_index: MemoIngredientIndex,
211215
) -> Option<NonNull<M>> {
212-
let memo = self.memos.memos.get(memo_ingredient_index.as_usize())?;
213-
let type_ = self.types.types.get(memo_ingredient_index.as_usize())?;
214-
assert_eq!(
215-
type_.type_id,
216-
TypeId::of::<M>(),
217-
"inconsistent type-id for `{memo_ingredient_index:?}`"
218-
);
219-
let memo = NonNull::new(memo.atomic_memo.load(Ordering::Acquire))?;
220-
// SAFETY: `type_id` check asserted above
221-
Some(unsafe { MemoEntryType::from_dummy(memo) })
216+
let MemoEntry { atomic_memo } = self.memos.memos.get(memo_ingredient_index.as_usize())?;
217+
218+
// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
219+
// corresponding `MemoTableTypes`, by construction.
220+
let type_ = unsafe {
221+
self.types
222+
.types
223+
.get_unchecked(memo_ingredient_index.as_usize())
224+
};
225+
226+
// Verify that the we are casting to the correct type.
227+
if type_.type_id != TypeId::of::<M>() {
228+
type_assert_failed(memo_ingredient_index);
229+
}
230+
231+
NonNull::new(atomic_memo.load(Ordering::Acquire))
232+
// SAFETY: We asserted that the type is correct above.
233+
.map(|memo| unsafe { MemoEntryType::from_dummy(memo) })
222234
}
223235

224236
#[cfg(feature = "salsa_unstable")]
@@ -256,27 +268,30 @@ impl MemoTableWithTypesMut<'_> {
256268
memo_ingredient_index: MemoIngredientIndex,
257269
f: impl FnOnce(&mut M),
258270
) {
259-
let Some(type_) = self.types.types.get(memo_ingredient_index.as_usize()) else {
260-
return;
261-
};
262-
assert_eq!(
263-
type_.type_id,
264-
TypeId::of::<M>(),
265-
"inconsistent type-id for `{memo_ingredient_index:?}`"
266-
);
267-
268-
// The memo table is pre-sized on creation based on the corresponding `MemoTableTypes`.
269271
let Some(MemoEntry { atomic_memo }) =
270272
self.memos.memos.get_mut(memo_ingredient_index.as_usize())
271273
else {
272274
return;
273275
};
274276

277+
// SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its
278+
// corresponding `MemoTableTypes`, by construction.
279+
let type_ = unsafe {
280+
self.types
281+
.types
282+
.get_unchecked(memo_ingredient_index.as_usize())
283+
};
284+
285+
// Verify that the we are casting to the correct type.
286+
if type_.type_id != TypeId::of::<M>() {
287+
type_assert_failed(memo_ingredient_index);
288+
}
289+
275290
let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else {
276291
return;
277292
};
278293

279-
// SAFETY: `type_id` check asserted above
294+
// SAFETY: We asserted that the type is correct above.
280295
f(unsafe { MemoEntryType::from_dummy(memo).as_mut() });
281296
}
282297

@@ -319,6 +334,12 @@ impl MemoTableWithTypesMut<'_> {
319334
}
320335
}
321336

337+
#[cold]
338+
#[inline(never)]
339+
fn type_assert_failed(memo_ingredient_index: MemoIngredientIndex) -> ! {
340+
panic!("inconsistent type-id for `{memo_ingredient_index:?}`")
341+
}
342+
322343
impl MemoEntry {
323344
/// # Safety
324345
///

src/tracked_struct.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,9 @@ where
443443
// lifetime erase for storage
444444
fields: unsafe { mem::transmute::<C::Fields<'db>, C::Fields<'static>>(fields) },
445445
revisions: C::new_revisions(current_deps.changed_at),
446-
memos: MemoTable::new(self.memo_table_types()),
446+
// SAFETY: We only ever access the memos of a value that we allocated through
447+
// our `MemoTableTypes`.
448+
memos: unsafe { MemoTable::new(self.memo_table_types()) },
447449
};
448450

449451
while let Some(id) = self.free_list.pop() {

0 commit comments

Comments
 (0)