diff --git a/.github/.cspell/project-dictionary.txt b/.github/.cspell/project-dictionary.txt index 8e7e4295..366828d6 100644 --- a/.github/.cspell/project-dictionary.txt +++ b/.github/.cspell/project-dictionary.txt @@ -39,6 +39,7 @@ opensbi prefetcher quadword rclass +reentrancy semihosting seqlock sifive diff --git a/.github/.cspell/rust-dependencies.txt b/.github/.cspell/rust-dependencies.txt index d952be89..9654e41b 100644 --- a/.github/.cspell/rust-dependencies.txt +++ b/.github/.cspell/rust-dependencies.txt @@ -3,10 +3,12 @@ assertions atomic +cell criterion critical crossbeam fastrand +once paste portable quickcheck diff --git a/Cargo.toml b/Cargo.toml index 5359aeaa..c792565d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,9 +62,10 @@ serde = { version = "1.0.103", optional = true, default-features = false } critical-section = { version = "1", optional = true } [dev-dependencies] -critical-section = { version = "1", features = ["std"] } +critical-section = { version = "1", features = ["restore-state-bool"] } crossbeam-utils = "0.8" fastrand = "1" +once_cell = "1" paste = "1" quickcheck = { default-features = false, git = "https://github.com/taiki-e/quickcheck.git", branch = "dev" } # https://github.com/BurntSushi/quickcheck/pull/304 + https://github.com/BurntSushi/quickcheck/pull/282 + lower MSRV serde = { version = "1", features = ["derive"] } diff --git a/src/tests/critical_section_std.rs b/src/tests/critical_section_std.rs new file mode 100644 index 00000000..c5cbfe7b --- /dev/null +++ b/src/tests/critical_section_std.rs @@ -0,0 +1,68 @@ +// Based on https://github.com/rust-embedded/critical-section/blob/v1.1.1/src/std.rs, +// but compatible with Rust 1.59 that we run test. + +use std::{ + cell::Cell, + mem::MaybeUninit, + sync::{Mutex, MutexGuard}, +}; + +use once_cell::sync::Lazy; + +static GLOBAL_MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); + +// This is initialized if a thread has acquired the CS, uninitialized otherwise. +static mut GLOBAL_GUARD: MaybeUninit> = MaybeUninit::uninit(); + +std::thread_local!(static IS_LOCKED: Cell = Cell::new(false)); + +struct StdCriticalSection; +critical_section::set_impl!(StdCriticalSection); + +unsafe impl critical_section::Impl for StdCriticalSection { + unsafe fn acquire() -> bool { + // Allow reentrancy by checking thread local state + IS_LOCKED.with(|l| { + if l.get() { + // CS already acquired in the current thread. + return true; + } + + // Note: it is fine to set this flag *before* acquiring the mutex because it's thread local. + // No other thread can see its value, there's no potential for races. + // This way, we hold the mutex for slightly less time. + l.set(true); + + // Not acquired in the current thread, acquire it. + let guard = match GLOBAL_MUTEX.lock() { + Ok(guard) => guard, + Err(err) => { + // Ignore poison on the global mutex in case a panic occurred + // while the mutex was held. + err.into_inner() + } + }; + unsafe { + GLOBAL_GUARD.write(guard); + } + + false + }) + } + + unsafe fn release(nested_cs: bool) { + if !nested_cs { + // SAFETY: As per the acquire/release safety contract, release can only be called + // if the critical section is acquired in the current thread, + // in which case we know the GLOBAL_GUARD is initialized. + unsafe { + GLOBAL_GUARD.as_mut_ptr().drop_in_place(); + } + + // Note: it is fine to clear this flag *after* releasing the mutex because it's thread local. + // No other thread can see its value, there's no potential for races. + // This way, we hold the mutex for slightly less time. + IS_LOCKED.with(|l| l.set(false)); + } + } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 1cfe813f..2154cabc 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -9,6 +9,8 @@ #[macro_use] pub(crate) mod helper; +#[cfg(feature = "critical-section")] +mod critical_section_std; #[cfg(feature = "serde")] mod serde;