From 7c6a4fc850d044ba557d541f194b15b2ecbe2ae4 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 2 Oct 2023 17:11:06 +0200 Subject: [PATCH] fix deadlock and add more tests --- sqld/src/connection/libsql.rs | 279 ++++++++++++++++++---------------- 1 file changed, 146 insertions(+), 133 deletions(-) diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index f2dc0d1f..29abf989 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -265,10 +265,13 @@ impl Default for TxnState { unsafe extern "C" fn busy_handler(state: *mut c_void, _retries: c_int) -> c_int { let state = &*(state as *mut TxnState); let lock = state.slot.read(); - // fast path - if lock.is_none() { - return 1; - } + // we take a reference to the slot we will attempt to steal. this is to make sure that we + // actually steal the correct lock. + let slot = match &*lock { + Some(slot) => slot.clone(), + // fast path: there is no slot, try to acquire the lock again + None => return 1, + }; tokio::runtime::Handle::current().block_on(async move { let timeout = { @@ -279,20 +282,28 @@ unsafe extern "C" fn busy_handler(state: *mut c_void, _retries: c_in }; tokio::select! { + // The connection has notified us that it's txn has terminated, try to acquire again _ = state.notify.notified() => 1, + // the current holder of the transaction has timedout, we will attempt to steal their + // lock. _ = timeout => { - // attempt to steal the lock - let mut lock = state.slot.write(); - // we attempt to take the slot, and steal the transaction from the other - // connection - if let Some(slot) = lock.take() { - if Instant::now() >= slot.timeout_at { - tracing::info!("stole transaction lock"); + // only a single connection gets to steal the lock, others retry + if let Some(mut lock) = state.slot.try_write() { + // We check that slot wasn't already stolen, and that their is still a slot. + // The ordering is relaxed because the atomic is only set under the slot lock. + if slot.is_stolen.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed).is_ok() { + // The connection holding the current txn will sets itsef as stolen when it + // detects a timeout, so if we arrive to this point, then there is + // necessarily a slot, and this slot has to be the one we attempted to + // steal. + assert!(lock.take().is_some()); + let conn = slot.conn.lock(); // we have a lock on the connection, we don't need mode than a // Relaxed store. - slot.is_stolen.store(true, std::sync::atomic::Ordering::Relaxed); conn.rollback(); + + tracing::info!("stole transaction lock"); } } 1 @@ -373,6 +384,8 @@ impl Connection { if let Some(slot) = &lock.slot { if slot.is_stolen.load(Ordering::Relaxed) || Instant::now() > slot.timeout_at { + // we mark ourselves as stolen to notify any waiting lock thief. + slot.is_stolen.store(true, Ordering::Relaxed); lock.rollback(); has_timeout = true; } @@ -419,7 +432,12 @@ impl Connection { builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?; - let state = if matches!(this.lock().conn.transaction_state(Some(DatabaseName::Main))?, Tx::Read | Tx::Write) { + let state = if matches!( + this.lock() + .conn + .transaction_state(Some(DatabaseName::Main))?, + Tx::Read | Tx::Write + ) { State::Txn } else { State::Init @@ -697,13 +715,11 @@ where #[cfg(test)] mod test { - use insta::assert_json_snapshot; use itertools::Itertools; use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS; use tempfile::tempdir; use tokio::task::JoinSet; - use crate::connection::Connection as _; use crate::query_result_builder::test::{test_driver, TestBuilder}; use crate::query_result_builder::QueryResultBuilder; use crate::DEFAULT_AUTO_CHECKPOINT; @@ -740,7 +756,7 @@ mod test { } #[tokio::test] - async fn txn_stealing() { + async fn txn_timeout_no_stealing() { let tmp = tempdir().unwrap(); let make_conn = MakeLibSqlConn::new( tmp.path().into(), @@ -757,122 +773,75 @@ mod test { .await .unwrap(); - let conn1 = make_conn.make_connection().await.unwrap(); - let conn2 = make_conn.make_connection().await.unwrap(); - - let mut join_set = JoinSet::new(); - let notify = Arc::new(Notify::new()); - - join_set.spawn({ - let notify = notify.clone(); - async move { - // 1. take an exclusive lock - let conn = conn1.inner.clone(); - let res = tokio::task::spawn_blocking(|| { - Connection::run( - conn, - Program::seq(&["BEGIN EXCLUSIVE"]), - TestBuilder::default(), - ) - .unwrap() - }) - .await - .unwrap(); - assert!(res.0.into_ret().into_iter().all(|x| x.is_ok())); - assert_eq!(res.1, State::Txn); - assert!(conn1.inner.lock().slot.is_some()); - // 2. notify other conn that lock was acquired - notify.notify_one(); - // 6. wait till other connection steals the lock - notify.notified().await; - // 7. get an error because txn timedout - let conn = conn1.inner.clone(); - // our lock was stolen - assert!(conn1 - .inner - .lock() - .slot - .as_ref() - .unwrap() - .is_stolen - .load(Ordering::Relaxed)); - let res = tokio::task::spawn_blocking(|| { - Connection::run( - conn, - Program::seq(&["CREATE TABLE TEST (x)"]), - TestBuilder::default(), - ) - .unwrap() - }) - .await - .unwrap(); + tokio::time::pause(); + let conn = make_conn.make_connection().await.unwrap(); + let (_builder, state) = Connection::run( + conn.inner.clone(), + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); - assert!(matches!(res.0.into_ret()[0], Err(Error::LibSqlTxTimeout))); + tokio::time::advance(TXN_TIMEOUT * 2).await; - let before = Instant::now(); - let conn = conn1.inner.clone(); - // 8. try to acquire lock again - let res = tokio::task::spawn_blocking(|| { - Connection::run( - conn, - Program::seq(&["CREATE TABLE TEST (x)"]), - TestBuilder::default(), - ) - .unwrap() - }) - .await - .unwrap(); + let (builder, state) = Connection::run( + conn.inner.clone(), + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Init); + assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); + } - assert!(res.0.into_ret().into_iter().all(|x| x.is_ok())); - // the lock must have been released before the timeout - assert!(before.elapsed() < TXN_TIMEOUT); - notify.notify_one(); - } - }); + #[tokio::test] + /// A bunch of txn try to acquire the lock, and never release it. They will try to steal the + /// lock one after the other. All txn should eventually acquire the write lock + async fn serialized_txn_timeouts() { + let tmp = tempdir().unwrap(); + let make_conn = MakeLibSqlConn::new( + tmp.path().into(), + &TRANSPARENT_METHODS, + || (), + Default::default(), + Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()), + Arc::new([]), + 100000000, + 100000000, + DEFAULT_AUTO_CHECKPOINT, + watch::channel(None).1, + ) + .await + .unwrap(); - join_set.spawn({ - let notify = notify.clone(); - async move { - // 3. wait for other connection to acquire lock - notify.notified().await; - // 4. try to acquire lock as well - let conn = conn2.inner.clone(); - tokio::task::spawn_blocking(|| { - Connection::run( - conn, - Program::seq(&["BEGIN EXCLUSIVE"]), - TestBuilder::default(), - ) - .unwrap(); - }) - .await + let mut set = JoinSet::new(); + for _ in 0..10 { + let conn = make_conn.make_connection().await.unwrap(); + set.spawn_blocking(move || { + let (builder, state) = Connection::run( + conn.inner, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) .unwrap(); - // 5. notify other that we could acquire the lock - notify.notify_one(); - - // 9. rollback before timeout - tokio::time::sleep(TXN_TIMEOUT / 2).await; - let conn = conn2.inner.clone(); - let slot = conn2.inner.lock().slot.as_ref().unwrap().clone(); - tokio::task::spawn_blocking(|| { - Connection::run(conn, Program::seq(&["ROLLBACK"]), TestBuilder::default()) - .unwrap(); - }) - .await - .unwrap(); - // rolling back caused to slot to b removed - assert!(conn2.inner.lock().slot.is_none()); - // the lock was *not* stolen - notify.notified().await; - assert!(!slot.is_stolen.load(Ordering::Relaxed)); - } - }); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + }); + } - while join_set.join_next().await.is_some() {} + tokio::time::pause(); + + while let Some(ret) = set.join_next().await { + assert!(ret.is_ok()); + // advance time by a bit more than the txn timeout + tokio::time::advance(TXN_TIMEOUT + Duration::from_millis(100)).await; + } } #[tokio::test] - async fn txn_timeout_no_stealing() { + /// verify that releasing a txn before the timeout + async fn release_before_timeout() { let tmp = tempdir().unwrap(); let make_conn = MakeLibSqlConn::new( tmp.path().into(), @@ -886,18 +855,62 @@ mod test { DEFAULT_AUTO_CHECKPOINT, watch::channel(None).1, ) - .await - .unwrap(); + .await + .unwrap(); - tokio::time::pause(); - let conn = make_conn.make_connection().await.unwrap(); - let (_builder, state) = Connection::run(conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default()).unwrap(); - assert_eq!(state, State::Txn); + let conn1 = make_conn.make_connection().await.unwrap(); + tokio::task::spawn_blocking({ + let conn = conn1.inner.clone(); + move || { + let (builder, state) = Connection::run( + conn, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + } + }) + .await + .unwrap(); - tokio::time::advance(TXN_TIMEOUT * 2).await; + let conn2 = make_conn.make_connection().await.unwrap(); + let handle = tokio::task::spawn_blocking({ + let conn = conn2.inner.clone(); + move || { + let before = Instant::now(); + let (builder, state) = Connection::run( + conn, + Program::seq(&["BEGIN IMMEDIATE"]), + TestBuilder::default(), + ) + .unwrap(); + assert_eq!(state, State::Txn); + assert!(builder.into_ret()[0].is_ok()); + before.elapsed() + } + }); - let (builder, state) = Connection::run(conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default()).unwrap(); - assert_eq!(state, State::Init); - assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout))); + let wait_time = TXN_TIMEOUT / 10; + tokio::time::sleep(wait_time).await; + + tokio::task::spawn_blocking({ + let conn = conn1.inner.clone(); + move || { + let (builder, state) = + Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default()) + .unwrap(); + assert_eq!(state, State::Init); + assert!(builder.into_ret()[0].is_ok()); + } + }) + .await + .unwrap(); + + let elapsed = handle.await.unwrap(); + + let epsilon = Duration::from_millis(100); + assert!((wait_time..wait_time + epsilon).contains(&elapsed)); } }