Skip to content

Commit 42a2ad8

Browse files
authored
Merge pull request #6 from mkroening/soundness
fix(soundness): disable interrupts on every `RefCell` flags change
2 parents c7c7183 + 9ed8781 commit 42a2ad8

File tree

2 files changed

+130
-46
lines changed

2 files changed

+130
-46
lines changed

src/interrupt_dropper.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use core::mem::ManuallyDrop;
2+
use core::ops::{Deref, DerefMut};
3+
4+
/// A wrapper for dropping values while interrupts are disabled.
5+
pub struct InterruptDropper<T> {
6+
inner: ManuallyDrop<T>,
7+
}
8+
9+
impl<T> From<T> for InterruptDropper<T> {
10+
#[inline]
11+
fn from(value: T) -> Self {
12+
Self {
13+
inner: ManuallyDrop::new(value),
14+
}
15+
}
16+
}
17+
18+
impl<T> InterruptDropper<T> {
19+
#[inline]
20+
pub fn into_inner(mut this: Self) -> T {
21+
// SAFETY: We never use `this` after this again.
22+
unsafe { ManuallyDrop::take(&mut this.inner) }
23+
}
24+
}
25+
26+
impl<T> Deref for InterruptDropper<T> {
27+
type Target = T;
28+
29+
#[inline]
30+
fn deref(&self) -> &Self::Target {
31+
self.inner.deref()
32+
}
33+
}
34+
35+
impl<T> DerefMut for InterruptDropper<T> {
36+
#[inline]
37+
fn deref_mut(&mut self) -> &mut Self::Target {
38+
self.inner.deref_mut()
39+
}
40+
}
41+
42+
impl<T> Drop for InterruptDropper<T> {
43+
#[inline]
44+
fn drop(&mut self) {
45+
let guard = interrupts::disable();
46+
// Drop `inner` as while we can guarentee interrupts are disabled
47+
// SAFETY: This is not exposed to safe code and is not called more than once
48+
unsafe { ManuallyDrop::drop(&mut self.inner) }
49+
drop(guard);
50+
}
51+
}

src/lib.rs

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,18 @@
6262
6363
#![cfg_attr(target_os = "none", no_std)]
6464

65+
mod interrupt_dropper;
6566
#[cfg(not(target_os = "none"))]
6667
mod local_key;
68+
6769
use core::cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut};
6870
use core::cmp::Ordering;
6971
use core::ops::{Deref, DerefMut};
7072
use core::{fmt, mem};
7173

74+
use self::interrupt_dropper::InterruptDropper;
7275
#[cfg(not(target_os = "none"))]
73-
pub use local_key::LocalKeyExt;
76+
pub use self::local_key::LocalKeyExt;
7477

7578
/// A mutable memory location with dynamically checked borrow rules
7679
///
@@ -263,10 +266,11 @@ impl<T: ?Sized> InterruptRefCell<T> {
263266
#[inline]
264267
#[cfg_attr(feature = "debug_interruptrefcell", track_caller)]
265268
pub fn try_borrow(&self) -> Result<InterruptRef<'_, T>, BorrowError> {
266-
let _guard = interrupts::disable();
267-
self.inner
268-
.try_borrow()
269-
.map(|inner| InterruptRef { inner, _guard })
269+
let guard = interrupts::disable();
270+
self.inner.try_borrow().map(|inner| {
271+
let inner = InterruptDropper::from(inner);
272+
InterruptRef { inner, guard }
273+
})
270274
}
271275

272276
/// Mutably borrows the wrapped value.
@@ -333,10 +337,11 @@ impl<T: ?Sized> InterruptRefCell<T> {
333337
#[inline]
334338
#[cfg_attr(feature = "debug_interruptrefcell", track_caller)]
335339
pub fn try_borrow_mut(&self) -> Result<InterruptRefMut<'_, T>, BorrowMutError> {
336-
let _guard = interrupts::disable();
337-
self.inner
338-
.try_borrow_mut()
339-
.map(|inner| InterruptRefMut { inner, _guard })
340+
let guard = interrupts::disable();
341+
self.inner.try_borrow_mut().map(|inner| {
342+
let inner = InterruptDropper::from(inner);
343+
InterruptRefMut { inner, guard }
344+
})
340345
}
341346

342347
/// Returns a raw pointer to the underlying data in this cell.
@@ -416,7 +421,10 @@ impl<T: ?Sized> InterruptRefCell<T> {
416421
/// ```
417422
#[inline]
418423
pub unsafe fn try_borrow_unguarded(&self) -> Result<&T, BorrowError> {
419-
self.inner.try_borrow_unguarded()
424+
let guard = interrupts::disable();
425+
let ret = self.inner.try_borrow_unguarded();
426+
drop(guard);
427+
ret
420428
}
421429
}
422430

@@ -547,8 +555,8 @@ impl<T> From<T> for InterruptRefCell<T> {
547555
///
548556
/// See the [module-level documentation](self) for more.
549557
pub struct InterruptRef<'b, T: ?Sized + 'b> {
550-
inner: Ref<'b, T>,
551-
_guard: interrupts::Guard,
558+
inner: InterruptDropper<Ref<'b, T>>,
559+
guard: interrupts::Guard,
552560
}
553561

554562
impl<T: ?Sized> Deref for InterruptRef<'_, T> {
@@ -573,10 +581,9 @@ impl<'b, T: ?Sized> InterruptRef<'b, T> {
573581
#[must_use]
574582
#[inline]
575583
pub fn clone(orig: &InterruptRef<'b, T>) -> InterruptRef<'b, T> {
576-
InterruptRef {
577-
inner: Ref::clone(&orig.inner),
578-
_guard: interrupts::disable(),
579-
}
584+
let guard = interrupts::disable();
585+
let inner = InterruptDropper::from(Ref::clone(&orig.inner));
586+
InterruptRef { inner, guard }
580587
}
581588

582589
/// Makes a new `InterruptRef` for a component of the borrowed data.
@@ -602,11 +609,9 @@ impl<'b, T: ?Sized> InterruptRef<'b, T> {
602609
where
603610
F: FnOnce(&T) -> &U,
604611
{
605-
let InterruptRef { inner, _guard } = orig;
606-
InterruptRef {
607-
inner: Ref::map(inner, f),
608-
_guard,
609-
}
612+
let InterruptRef { inner, guard } = orig;
613+
let inner = InterruptDropper::from(Ref::map(InterruptDropper::into_inner(inner), f));
614+
InterruptRef { inner, guard }
610615
}
611616

612617
/// Makes a new `InterruptRef` for an optional component of the borrowed data. The
@@ -638,10 +643,24 @@ impl<'b, T: ?Sized> InterruptRef<'b, T> {
638643
where
639644
F: FnOnce(&T) -> Option<&U>,
640645
{
641-
let InterruptRef { inner, _guard } = orig;
642-
match Ref::filter_map(inner, f) {
643-
Ok(inner) => Ok(InterruptRef { inner, _guard }),
644-
Err(inner) => Err(InterruptRef { inner, _guard }),
646+
let guard = interrupts::disable();
647+
let filter_map = Ref::filter_map(InterruptDropper::into_inner(orig.inner), f);
648+
drop(guard);
649+
match filter_map {
650+
Ok(inner) => {
651+
let inner = InterruptDropper::from(inner);
652+
Ok(InterruptRef {
653+
inner,
654+
guard: orig.guard,
655+
})
656+
}
657+
Err(inner) => {
658+
let inner = InterruptDropper::from(inner);
659+
Err(InterruptRef {
660+
inner,
661+
guard: orig.guard,
662+
})
663+
}
645664
}
646665
}
647666

@@ -673,15 +692,16 @@ impl<'b, T: ?Sized> InterruptRef<'b, T> {
673692
where
674693
F: FnOnce(&T) -> (&U, &V),
675694
{
676-
let (a, b) = Ref::map_split(orig.inner, f);
695+
let guard = interrupts::disable();
696+
let (a, b) = Ref::map_split(InterruptDropper::into_inner(orig.inner), f);
677697
(
678698
InterruptRef {
679-
inner: a,
680-
_guard: orig._guard,
699+
inner: InterruptDropper::from(a),
700+
guard,
681701
},
682702
InterruptRef {
683-
inner: b,
684-
_guard: interrupts::disable(),
703+
inner: InterruptDropper::from(b),
704+
guard: orig.guard,
685705
},
686706
)
687707
}
@@ -728,11 +748,9 @@ impl<'b, T: ?Sized> InterruptRefMut<'b, T> {
728748
where
729749
F: FnOnce(&mut T) -> &mut U,
730750
{
731-
let InterruptRefMut { inner, _guard } = orig;
732-
InterruptRefMut {
733-
inner: RefMut::map(inner, f),
734-
_guard,
735-
}
751+
let InterruptRefMut { inner, guard } = orig;
752+
let inner = InterruptDropper::from(RefMut::map(InterruptDropper::into_inner(inner), f));
753+
InterruptRefMut { inner, guard }
736754
}
737755

738756
/// Makes a new `InterruptRefMut` for an optional component of the borrowed data. The
@@ -772,10 +790,24 @@ impl<'b, T: ?Sized> InterruptRefMut<'b, T> {
772790
where
773791
F: FnOnce(&mut T) -> Option<&mut U>,
774792
{
775-
let InterruptRefMut { inner, _guard } = orig;
776-
match RefMut::filter_map(inner, f) {
777-
Ok(inner) => Ok(InterruptRefMut { inner, _guard }),
778-
Err(inner) => Err(InterruptRefMut { inner, _guard }),
793+
let guard = interrupts::disable();
794+
let filter_map = RefMut::filter_map(InterruptDropper::into_inner(orig.inner), f);
795+
drop(guard);
796+
match filter_map {
797+
Ok(inner) => {
798+
let inner = InterruptDropper::from(inner);
799+
Ok(InterruptRefMut {
800+
inner,
801+
guard: orig.guard,
802+
})
803+
}
804+
Err(inner) => {
805+
let inner = InterruptDropper::from(inner);
806+
Err(InterruptRefMut {
807+
inner,
808+
guard: orig.guard,
809+
})
810+
}
779811
}
780812
}
781813

@@ -812,15 +844,16 @@ impl<'b, T: ?Sized> InterruptRefMut<'b, T> {
812844
where
813845
F: FnOnce(&mut T) -> (&mut U, &mut V),
814846
{
815-
let (a, b) = RefMut::map_split(orig.inner, f);
847+
let guard = interrupts::disable();
848+
let (a, b) = RefMut::map_split(InterruptDropper::into_inner(orig.inner), f);
816849
(
817850
InterruptRefMut {
818-
inner: a,
819-
_guard: orig._guard,
851+
inner: InterruptDropper::from(a),
852+
guard,
820853
},
821854
InterruptRefMut {
822-
inner: b,
823-
_guard: interrupts::disable(),
855+
inner: InterruptDropper::from(b),
856+
guard: orig.guard,
824857
},
825858
)
826859
}
@@ -830,8 +863,8 @@ impl<'b, T: ?Sized> InterruptRefMut<'b, T> {
830863
///
831864
/// See the [module-level documentation](self) for more.
832865
pub struct InterruptRefMut<'b, T: ?Sized + 'b> {
833-
inner: RefMut<'b, T>,
834-
_guard: interrupts::Guard,
866+
inner: InterruptDropper<RefMut<'b, T>>,
867+
guard: interrupts::Guard,
835868
}
836869

837870
impl<T: ?Sized> Deref for InterruptRefMut<'_, T> {

0 commit comments

Comments
 (0)