Skip to content
This repository has been archived by the owner on Oct 18, 2023. It is now read-only.

Commit

Permalink
fix deadlock and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MarinPostma committed Oct 3, 2023
1 parent 61d61dd commit 7c6a4fc
Showing 1 changed file with 146 additions and 133 deletions.
279 changes: 146 additions & 133 deletions sqld/src/connection/libsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,13 @@ impl<W: WalHook> Default for TxnState<W> {
unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_int) -> c_int {
let state = &*(state as *mut TxnState<W>);
let lock = state.slot.read();
// fast path
if lock.is_none() {
return 1;
}
// we take a reference to the slot we will attempt to steal. this is to make sure that we
// actually steal the correct lock.
let slot = match &*lock {
Some(slot) => slot.clone(),
// fast path: there is no slot, try to acquire the lock again
None => return 1,
};

tokio::runtime::Handle::current().block_on(async move {
let timeout = {
Expand All @@ -279,20 +282,28 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
};

tokio::select! {
// The connection has notified us that it's txn has terminated, try to acquire again
_ = state.notify.notified() => 1,
// the current holder of the transaction has timedout, we will attempt to steal their
// lock.
_ = timeout => {
// attempt to steal the lock
let mut lock = state.slot.write();
// we attempt to take the slot, and steal the transaction from the other
// connection
if let Some(slot) = lock.take() {
if Instant::now() >= slot.timeout_at {
tracing::info!("stole transaction lock");
// only a single connection gets to steal the lock, others retry
if let Some(mut lock) = state.slot.try_write() {
// We check that slot wasn't already stolen, and that their is still a slot.
// The ordering is relaxed because the atomic is only set under the slot lock.
if slot.is_stolen.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed).is_ok() {
// The connection holding the current txn will sets itsef as stolen when it
// detects a timeout, so if we arrive to this point, then there is
// necessarily a slot, and this slot has to be the one we attempted to
// steal.
assert!(lock.take().is_some());

let conn = slot.conn.lock();
// we have a lock on the connection, we don't need mode than a
// Relaxed store.
slot.is_stolen.store(true, std::sync::atomic::Ordering::Relaxed);
conn.rollback();

tracing::info!("stole transaction lock");
}
}
1
Expand Down Expand Up @@ -373,6 +384,8 @@ impl<W: WalHook> Connection<W> {

if let Some(slot) = &lock.slot {
if slot.is_stolen.load(Ordering::Relaxed) || Instant::now() > slot.timeout_at {
// we mark ourselves as stolen to notify any waiting lock thief.
slot.is_stolen.store(true, Ordering::Relaxed);
lock.rollback();
has_timeout = true;
}
Expand Down Expand Up @@ -419,7 +432,12 @@ impl<W: WalHook> Connection<W> {

builder.finish(*this.lock().current_frame_no_receiver.borrow_and_update())?;

let state = if matches!(this.lock().conn.transaction_state(Some(DatabaseName::Main))?, Tx::Read | Tx::Write) {
let state = if matches!(
this.lock()
.conn
.transaction_state(Some(DatabaseName::Main))?,
Tx::Read | Tx::Write
) {
State::Txn
} else {
State::Init
Expand Down Expand Up @@ -697,13 +715,11 @@ where

#[cfg(test)]
mod test {
use insta::assert_json_snapshot;
use itertools::Itertools;
use sqld_libsql_bindings::wal_hook::TRANSPARENT_METHODS;
use tempfile::tempdir;
use tokio::task::JoinSet;

use crate::connection::Connection as _;
use crate::query_result_builder::test::{test_driver, TestBuilder};
use crate::query_result_builder::QueryResultBuilder;
use crate::DEFAULT_AUTO_CHECKPOINT;
Expand Down Expand Up @@ -740,7 +756,7 @@ mod test {
}

#[tokio::test]
async fn txn_stealing() {
async fn txn_timeout_no_stealing() {
let tmp = tempdir().unwrap();
let make_conn = MakeLibSqlConn::new(
tmp.path().into(),
Expand All @@ -757,122 +773,75 @@ mod test {
.await
.unwrap();

let conn1 = make_conn.make_connection().await.unwrap();
let conn2 = make_conn.make_connection().await.unwrap();

let mut join_set = JoinSet::new();
let notify = Arc::new(Notify::new());

join_set.spawn({
let notify = notify.clone();
async move {
// 1. take an exclusive lock
let conn = conn1.inner.clone();
let res = tokio::task::spawn_blocking(|| {
Connection::run(
conn,
Program::seq(&["BEGIN EXCLUSIVE"]),
TestBuilder::default(),
)
.unwrap()
})
.await
.unwrap();
assert!(res.0.into_ret().into_iter().all(|x| x.is_ok()));
assert_eq!(res.1, State::Txn);
assert!(conn1.inner.lock().slot.is_some());
// 2. notify other conn that lock was acquired
notify.notify_one();
// 6. wait till other connection steals the lock
notify.notified().await;
// 7. get an error because txn timedout
let conn = conn1.inner.clone();
// our lock was stolen
assert!(conn1
.inner
.lock()
.slot
.as_ref()
.unwrap()
.is_stolen
.load(Ordering::Relaxed));
let res = tokio::task::spawn_blocking(|| {
Connection::run(
conn,
Program::seq(&["CREATE TABLE TEST (x)"]),
TestBuilder::default(),
)
.unwrap()
})
.await
.unwrap();
tokio::time::pause();
let conn = make_conn.make_connection().await.unwrap();
let (_builder, state) = Connection::run(
conn.inner.clone(),
Program::seq(&["BEGIN IMMEDIATE"]),
TestBuilder::default(),
)
.unwrap();
assert_eq!(state, State::Txn);

assert!(matches!(res.0.into_ret()[0], Err(Error::LibSqlTxTimeout)));
tokio::time::advance(TXN_TIMEOUT * 2).await;

let before = Instant::now();
let conn = conn1.inner.clone();
// 8. try to acquire lock again
let res = tokio::task::spawn_blocking(|| {
Connection::run(
conn,
Program::seq(&["CREATE TABLE TEST (x)"]),
TestBuilder::default(),
)
.unwrap()
})
.await
.unwrap();
let (builder, state) = Connection::run(
conn.inner.clone(),
Program::seq(&["BEGIN IMMEDIATE"]),
TestBuilder::default(),
)
.unwrap();
assert_eq!(state, State::Init);
assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout)));
}

assert!(res.0.into_ret().into_iter().all(|x| x.is_ok()));
// the lock must have been released before the timeout
assert!(before.elapsed() < TXN_TIMEOUT);
notify.notify_one();
}
});
#[tokio::test]
/// A bunch of txn try to acquire the lock, and never release it. They will try to steal the
/// lock one after the other. All txn should eventually acquire the write lock
async fn serialized_txn_timeouts() {
let tmp = tempdir().unwrap();
let make_conn = MakeLibSqlConn::new(
tmp.path().into(),
&TRANSPARENT_METHODS,
|| (),
Default::default(),
Arc::new(DatabaseConfigStore::load(tmp.path()).unwrap()),
Arc::new([]),
100000000,
100000000,
DEFAULT_AUTO_CHECKPOINT,
watch::channel(None).1,
)
.await
.unwrap();

join_set.spawn({
let notify = notify.clone();
async move {
// 3. wait for other connection to acquire lock
notify.notified().await;
// 4. try to acquire lock as well
let conn = conn2.inner.clone();
tokio::task::spawn_blocking(|| {
Connection::run(
conn,
Program::seq(&["BEGIN EXCLUSIVE"]),
TestBuilder::default(),
)
.unwrap();
})
.await
let mut set = JoinSet::new();
for _ in 0..10 {
let conn = make_conn.make_connection().await.unwrap();
set.spawn_blocking(move || {
let (builder, state) = Connection::run(
conn.inner,
Program::seq(&["BEGIN IMMEDIATE"]),
TestBuilder::default(),
)
.unwrap();
// 5. notify other that we could acquire the lock
notify.notify_one();

// 9. rollback before timeout
tokio::time::sleep(TXN_TIMEOUT / 2).await;
let conn = conn2.inner.clone();
let slot = conn2.inner.lock().slot.as_ref().unwrap().clone();
tokio::task::spawn_blocking(|| {
Connection::run(conn, Program::seq(&["ROLLBACK"]), TestBuilder::default())
.unwrap();
})
.await
.unwrap();
// rolling back caused to slot to b removed
assert!(conn2.inner.lock().slot.is_none());
// the lock was *not* stolen
notify.notified().await;
assert!(!slot.is_stolen.load(Ordering::Relaxed));
}
});
assert_eq!(state, State::Txn);
assert!(builder.into_ret()[0].is_ok());
});
}

while join_set.join_next().await.is_some() {}
tokio::time::pause();

while let Some(ret) = set.join_next().await {
assert!(ret.is_ok());
// advance time by a bit more than the txn timeout
tokio::time::advance(TXN_TIMEOUT + Duration::from_millis(100)).await;
}
}

#[tokio::test]
async fn txn_timeout_no_stealing() {
/// verify that releasing a txn before the timeout
async fn release_before_timeout() {
let tmp = tempdir().unwrap();
let make_conn = MakeLibSqlConn::new(
tmp.path().into(),
Expand All @@ -886,18 +855,62 @@ mod test {
DEFAULT_AUTO_CHECKPOINT,
watch::channel(None).1,
)
.await
.unwrap();
.await
.unwrap();

tokio::time::pause();
let conn = make_conn.make_connection().await.unwrap();
let (_builder, state) = Connection::run(conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default()).unwrap();
assert_eq!(state, State::Txn);
let conn1 = make_conn.make_connection().await.unwrap();
tokio::task::spawn_blocking({
let conn = conn1.inner.clone();
move || {
let (builder, state) = Connection::run(
conn,
Program::seq(&["BEGIN IMMEDIATE"]),
TestBuilder::default(),
)
.unwrap();
assert_eq!(state, State::Txn);
assert!(builder.into_ret()[0].is_ok());
}
})
.await
.unwrap();

tokio::time::advance(TXN_TIMEOUT * 2).await;
let conn2 = make_conn.make_connection().await.unwrap();
let handle = tokio::task::spawn_blocking({
let conn = conn2.inner.clone();
move || {
let before = Instant::now();
let (builder, state) = Connection::run(
conn,
Program::seq(&["BEGIN IMMEDIATE"]),
TestBuilder::default(),
)
.unwrap();
assert_eq!(state, State::Txn);
assert!(builder.into_ret()[0].is_ok());
before.elapsed()
}
});

let (builder, state) = Connection::run(conn.inner.clone(), Program::seq(&["BEGIN IMMEDIATE"]), TestBuilder::default()).unwrap();
assert_eq!(state, State::Init);
assert!(matches!(builder.into_ret()[0], Err(Error::LibSqlTxTimeout)));
let wait_time = TXN_TIMEOUT / 10;
tokio::time::sleep(wait_time).await;

tokio::task::spawn_blocking({
let conn = conn1.inner.clone();
move || {
let (builder, state) =
Connection::run(conn, Program::seq(&["COMMIT"]), TestBuilder::default())
.unwrap();
assert_eq!(state, State::Init);
assert!(builder.into_ret()[0].is_ok());
}
})
.await
.unwrap();

let elapsed = handle.await.unwrap();

let epsilon = Duration::from_millis(100);
assert!((wait_time..wait_time + epsilon).contains(&elapsed));
}
}

0 comments on commit 7c6a4fc

Please sign in to comment.