diff --git a/serial_test/src/code_lock.rs b/serial_test/src/code_lock.rs index cbb2161..dfb1a13 100644 --- a/serial_test/src/code_lock.rs +++ b/serial_test/src/code_lock.rs @@ -30,7 +30,6 @@ impl UniqueReentrantMutex { self.locks.parallel_count() } - #[cfg(test)] pub fn is_locked(&self) -> bool { self.locks.is_locked() } @@ -44,6 +43,63 @@ pub(crate) fn global_locks() -> &'static HashMap { LOCKS.get_or_init(HashMap::new) } +/// Check if we are holding a serial lock +/// +/// Can be used to assert that a piece of code can only be called +/// from a test marked `#[serial]`. +/// +/// Example, with `#[serial]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// fn do_something_in_need_of_serialization() { +/// assert!(is_locked_serially(None)); +/// +/// // ... +/// } +/// +/// #[test] +/// # fn unused() {} +/// #[serial] +/// fn main() { +/// do_something_in_need_of_serialization(); +/// } +/// ``` +/// +/// Example, missing `#[serial]`: +/// +/// ```should_panic +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// // #[serial] // <-- missing +/// fn main() { +/// assert!(is_locked_serially(None)); +/// } +/// ``` +/// +/// Example, `#[test(some_key)]`: +/// +/// ``` +/// use serial_test::{is_locked_serially, serial}; +/// +/// #[test] +/// # fn unused() {} +/// #[serial(some_key)] +/// fn main() { +/// assert!(is_locked_serially(Some("some_key"))); +/// assert!(!is_locked_serially(None)); +/// } +/// ``` +pub fn is_locked_serially(name: Option<&str>) -> bool { + global_locks() + .get(name.unwrap_or_default()) + .map(|lock| lock.get().is_locked()) + .unwrap_or_default() +} + static MUTEX_ID: AtomicU32 = AtomicU32::new(1); impl UniqueReentrantMutex { diff --git a/serial_test/src/lib.rs b/serial_test/src/lib.rs index d6519f0..a5d71c0 100644 --- a/serial_test/src/lib.rs +++ b/serial_test/src/lib.rs @@ -112,3 +112,5 @@ pub use serial_test_derive::{parallel, serial}; #[cfg(feature = "file_locks")] pub use serial_test_derive::{file_parallel, file_serial}; + +pub use code_lock::is_locked_serially; diff --git a/serial_test/src/rwlock.rs b/serial_test/src/rwlock.rs index 0be0c90..95ed34a 100644 --- a/serial_test/src/rwlock.rs +++ b/serial_test/src/rwlock.rs @@ -49,7 +49,6 @@ impl Locks { } } - #[cfg(test)] pub fn is_locked(&self) -> bool { self.arc.serial.is_locked() }