@@ -265,10 +265,13 @@ impl<W: WalHook> Default for TxnState<W> {
265
265
unsafe extern "C" fn busy_handler < W : WalHook > ( state : * mut c_void , _retries : c_int ) -> c_int {
266
266
let state = & * ( state as * mut TxnState < W > ) ;
267
267
let lock = state. slot . read ( ) ;
268
- // fast path
269
- if lock. is_none ( ) {
270
- return 1 ;
271
- }
268
+ // we take a reference to the slot we will attempt to steal. this is to make sure that we
269
+ // actually steal the correct lock.
270
+ let slot = match & * lock {
271
+ Some ( slot) => slot. clone ( ) ,
272
+ // fast path: there is no slot, try to acquire the lock again
273
+ None => return 1 ,
274
+ } ;
272
275
273
276
tokio:: runtime:: Handle :: current ( ) . block_on ( async move {
274
277
let timeout = {
@@ -279,20 +282,28 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
279
282
} ;
280
283
281
284
tokio:: select! {
285
+ // The connection has notified us that it's txn has terminated, try to acquire again
282
286
_ = state. notify. notified( ) => 1 ,
287
+ // the current holder of the transaction has timedout, we will attempt to steal their
288
+ // lock.
283
289
_ = timeout => {
284
- // attempt to steal the lock
285
- let mut lock = state. slot. write( ) ;
286
- // we attempt to take the slot, and steal the transaction from the other
287
- // connection
288
- if let Some ( slot) = lock. take( ) {
289
- if Instant :: now( ) >= slot. timeout_at {
290
- tracing:: info!( "stole transaction lock" ) ;
290
+ // only a single connection gets to steal the lock, others retry
291
+ if let Some ( mut lock) = state. slot. try_write( ) {
292
+ // We check that slot wasn't already stolen, and that their is still a slot.
293
+ // The ordering is relaxed because the atomic is only set under the slot lock.
294
+ if slot. is_stolen. compare_exchange( false , true , Ordering :: Relaxed , Ordering :: Relaxed ) . is_ok( ) {
295
+ // The connection holding the current txn will sets itsef as stolen when it
296
+ // detects a timeout, so if we arrive to this point, then there is
297
+ // necessarily a slot, and this slot has to be the one we attempted to
298
+ // steal.
299
+ assert!( lock. take( ) . is_some( ) ) ;
300
+
291
301
let conn = slot. conn. lock( ) ;
292
302
// we have a lock on the connection, we don't need mode than a
293
303
// Relaxed store.
294
- slot. is_stolen. store( true , std:: sync:: atomic:: Ordering :: Relaxed ) ;
295
304
conn. rollback( ) ;
305
+
306
+ tracing:: info!( "stole transaction lock" ) ;
296
307
}
297
308
}
298
309
1
@@ -373,6 +384,8 @@ impl<W: WalHook> Connection<W> {
373
384
374
385
if let Some ( slot) = & lock. slot {
375
386
if slot. is_stolen . load ( Ordering :: Relaxed ) || Instant :: now ( ) > slot. timeout_at {
387
+ // we mark ourselves as stolen to notify any waiting lock thief.
388
+ slot. is_stolen . store ( true , Ordering :: Relaxed ) ;
376
389
lock. rollback ( ) ;
377
390
has_timeout = true ;
378
391
}
@@ -419,7 +432,12 @@ impl<W: WalHook> Connection<W> {
419
432
420
433
builder. finish ( * this. lock ( ) . current_frame_no_receiver . borrow_and_update ( ) ) ?;
421
434
422
- let state = if matches ! ( this. lock( ) . conn. transaction_state( Some ( DatabaseName :: Main ) ) ?, Tx :: Read | Tx :: Write ) {
435
+ let state = if matches ! (
436
+ this. lock( )
437
+ . conn
438
+ . transaction_state( Some ( DatabaseName :: Main ) ) ?,
439
+ Tx :: Read | Tx :: Write
440
+ ) {
423
441
State :: Txn
424
442
} else {
425
443
State :: Init
@@ -697,13 +715,11 @@ where
697
715
698
716
#[ cfg( test) ]
699
717
mod test {
700
- use insta:: assert_json_snapshot;
701
718
use itertools:: Itertools ;
702
719
use sqld_libsql_bindings:: wal_hook:: TRANSPARENT_METHODS ;
703
720
use tempfile:: tempdir;
704
721
use tokio:: task:: JoinSet ;
705
722
706
- use crate :: connection:: Connection as _;
707
723
use crate :: query_result_builder:: test:: { test_driver, TestBuilder } ;
708
724
use crate :: query_result_builder:: QueryResultBuilder ;
709
725
use crate :: DEFAULT_AUTO_CHECKPOINT ;
@@ -740,7 +756,7 @@ mod test {
740
756
}
741
757
742
758
#[ tokio:: test]
743
- async fn txn_stealing ( ) {
759
+ async fn txn_timeout_no_stealing ( ) {
744
760
let tmp = tempdir ( ) . unwrap ( ) ;
745
761
let make_conn = MakeLibSqlConn :: new (
746
762
tmp. path ( ) . into ( ) ,
@@ -757,122 +773,75 @@ mod test {
757
773
. await
758
774
. unwrap ( ) ;
759
775
760
- let conn1 = make_conn. make_connection ( ) . await . unwrap ( ) ;
761
- let conn2 = make_conn. make_connection ( ) . await . unwrap ( ) ;
762
-
763
- let mut join_set = JoinSet :: new ( ) ;
764
- let notify = Arc :: new ( Notify :: new ( ) ) ;
765
-
766
- join_set. spawn ( {
767
- let notify = notify. clone ( ) ;
768
- async move {
769
- // 1. take an exclusive lock
770
- let conn = conn1. inner . clone ( ) ;
771
- let res = tokio:: task:: spawn_blocking ( || {
772
- Connection :: run (
773
- conn,
774
- Program :: seq ( & [ "BEGIN EXCLUSIVE" ] ) ,
775
- TestBuilder :: default ( ) ,
776
- )
777
- . unwrap ( )
778
- } )
779
- . await
780
- . unwrap ( ) ;
781
- assert ! ( res. 0 . into_ret( ) . into_iter( ) . all( |x| x. is_ok( ) ) ) ;
782
- assert_eq ! ( res. 1 , State :: Txn ) ;
783
- assert ! ( conn1. inner. lock( ) . slot. is_some( ) ) ;
784
- // 2. notify other conn that lock was acquired
785
- notify. notify_one ( ) ;
786
- // 6. wait till other connection steals the lock
787
- notify. notified ( ) . await ;
788
- // 7. get an error because txn timedout
789
- let conn = conn1. inner . clone ( ) ;
790
- // our lock was stolen
791
- assert ! ( conn1
792
- . inner
793
- . lock( )
794
- . slot
795
- . as_ref( )
796
- . unwrap( )
797
- . is_stolen
798
- . load( Ordering :: Relaxed ) ) ;
799
- let res = tokio:: task:: spawn_blocking ( || {
800
- Connection :: run (
801
- conn,
802
- Program :: seq ( & [ "CREATE TABLE TEST (x)" ] ) ,
803
- TestBuilder :: default ( ) ,
804
- )
805
- . unwrap ( )
806
- } )
807
- . await
808
- . unwrap ( ) ;
776
+ tokio:: time:: pause ( ) ;
777
+ let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
778
+ let ( _builder, state) = Connection :: run (
779
+ conn. inner . clone ( ) ,
780
+ Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
781
+ TestBuilder :: default ( ) ,
782
+ )
783
+ . unwrap ( ) ;
784
+ assert_eq ! ( state, State :: Txn ) ;
809
785
810
- assert ! ( matches! ( res . 0 . into_ret ( ) [ 0 ] , Err ( Error :: LibSqlTxTimeout ) ) ) ;
786
+ tokio :: time :: advance ( TXN_TIMEOUT * 2 ) . await ;
811
787
812
- let before = Instant :: now ( ) ;
813
- let conn = conn1. inner . clone ( ) ;
814
- // 8. try to acquire lock again
815
- let res = tokio:: task:: spawn_blocking ( || {
816
- Connection :: run (
817
- conn,
818
- Program :: seq ( & [ "CREATE TABLE TEST (x)" ] ) ,
819
- TestBuilder :: default ( ) ,
820
- )
821
- . unwrap ( )
822
- } )
823
- . await
824
- . unwrap ( ) ;
788
+ let ( builder, state) = Connection :: run (
789
+ conn. inner . clone ( ) ,
790
+ Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
791
+ TestBuilder :: default ( ) ,
792
+ )
793
+ . unwrap ( ) ;
794
+ assert_eq ! ( state, State :: Init ) ;
795
+ assert ! ( matches!( builder. into_ret( ) [ 0 ] , Err ( Error :: LibSqlTxTimeout ) ) ) ;
796
+ }
825
797
826
- assert ! ( res. 0 . into_ret( ) . into_iter( ) . all( |x| x. is_ok( ) ) ) ;
827
- // the lock must have been released before the timeout
828
- assert ! ( before. elapsed( ) < TXN_TIMEOUT ) ;
829
- notify. notify_one ( ) ;
830
- }
831
- } ) ;
798
+ #[ tokio:: test]
799
+ /// A bunch of txn try to acquire the lock, and never release it. They will try to steal the
800
+ /// lock one after the other. All txn should eventually acquire the write lock
801
+ async fn serialized_txn_timeouts ( ) {
802
+ let tmp = tempdir ( ) . unwrap ( ) ;
803
+ let make_conn = MakeLibSqlConn :: new (
804
+ tmp. path ( ) . into ( ) ,
805
+ & TRANSPARENT_METHODS ,
806
+ || ( ) ,
807
+ Default :: default ( ) ,
808
+ Arc :: new ( DatabaseConfigStore :: load ( tmp. path ( ) ) . unwrap ( ) ) ,
809
+ Arc :: new ( [ ] ) ,
810
+ 100000000 ,
811
+ 100000000 ,
812
+ DEFAULT_AUTO_CHECKPOINT ,
813
+ watch:: channel ( None ) . 1 ,
814
+ )
815
+ . await
816
+ . unwrap ( ) ;
832
817
833
- join_set. spawn ( {
834
- let notify = notify. clone ( ) ;
835
- async move {
836
- // 3. wait for other connection to acquire lock
837
- notify. notified ( ) . await ;
838
- // 4. try to acquire lock as well
839
- let conn = conn2. inner . clone ( ) ;
840
- tokio:: task:: spawn_blocking ( || {
841
- Connection :: run (
842
- conn,
843
- Program :: seq ( & [ "BEGIN EXCLUSIVE" ] ) ,
844
- TestBuilder :: default ( ) ,
845
- )
846
- . unwrap ( ) ;
847
- } )
848
- . await
818
+ let mut set = JoinSet :: new ( ) ;
819
+ for _ in 0 ..10 {
820
+ let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
821
+ set. spawn_blocking ( move || {
822
+ let ( builder, state) = Connection :: run (
823
+ conn. inner ,
824
+ Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
825
+ TestBuilder :: default ( ) ,
826
+ )
849
827
. unwrap ( ) ;
850
- // 5. notify other that we could acquire the lock
851
- notify. notify_one ( ) ;
852
-
853
- // 9. rollback before timeout
854
- tokio:: time:: sleep ( TXN_TIMEOUT / 2 ) . await ;
855
- let conn = conn2. inner . clone ( ) ;
856
- let slot = conn2. inner . lock ( ) . slot . as_ref ( ) . unwrap ( ) . clone ( ) ;
857
- tokio:: task:: spawn_blocking ( || {
858
- Connection :: run ( conn, Program :: seq ( & [ "ROLLBACK" ] ) , TestBuilder :: default ( ) )
859
- . unwrap ( ) ;
860
- } )
861
- . await
862
- . unwrap ( ) ;
863
- // rolling back caused to slot to b removed
864
- assert ! ( conn2. inner. lock( ) . slot. is_none( ) ) ;
865
- // the lock was *not* stolen
866
- notify. notified ( ) . await ;
867
- assert ! ( !slot. is_stolen. load( Ordering :: Relaxed ) ) ;
868
- }
869
- } ) ;
828
+ assert_eq ! ( state, State :: Txn ) ;
829
+ assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
830
+ } ) ;
831
+ }
870
832
871
- while join_set. join_next ( ) . await . is_some ( ) { }
833
+ tokio:: time:: pause ( ) ;
834
+
835
+ while let Some ( ret) = set. join_next ( ) . await {
836
+ assert ! ( ret. is_ok( ) ) ;
837
+ // advance time by a bit more than the txn timeout
838
+ tokio:: time:: advance ( TXN_TIMEOUT + Duration :: from_millis ( 100 ) ) . await ;
839
+ }
872
840
}
873
841
874
842
#[ tokio:: test]
875
- async fn txn_timeout_no_stealing ( ) {
843
+ /// verify that releasing a txn before the timeout
844
+ async fn release_before_timeout ( ) {
876
845
let tmp = tempdir ( ) . unwrap ( ) ;
877
846
let make_conn = MakeLibSqlConn :: new (
878
847
tmp. path ( ) . into ( ) ,
@@ -886,18 +855,62 @@ mod test {
886
855
DEFAULT_AUTO_CHECKPOINT ,
887
856
watch:: channel ( None ) . 1 ,
888
857
)
889
- . await
890
- . unwrap ( ) ;
858
+ . await
859
+ . unwrap ( ) ;
891
860
892
- tokio:: time:: pause ( ) ;
893
- let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
894
- let ( _builder, state) = Connection :: run ( conn. inner . clone ( ) , Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) , TestBuilder :: default ( ) ) . unwrap ( ) ;
895
- assert_eq ! ( state, State :: Txn ) ;
861
+ let conn1 = make_conn. make_connection ( ) . await . unwrap ( ) ;
862
+ tokio:: task:: spawn_blocking ( {
863
+ let conn = conn1. inner . clone ( ) ;
864
+ move || {
865
+ let ( builder, state) = Connection :: run (
866
+ conn,
867
+ Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
868
+ TestBuilder :: default ( ) ,
869
+ )
870
+ . unwrap ( ) ;
871
+ assert_eq ! ( state, State :: Txn ) ;
872
+ assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
873
+ }
874
+ } )
875
+ . await
876
+ . unwrap ( ) ;
896
877
897
- tokio:: time:: advance ( TXN_TIMEOUT * 2 ) . await ;
878
+ let conn2 = make_conn. make_connection ( ) . await . unwrap ( ) ;
879
+ let handle = tokio:: task:: spawn_blocking ( {
880
+ let conn = conn2. inner . clone ( ) ;
881
+ move || {
882
+ let before = Instant :: now ( ) ;
883
+ let ( builder, state) = Connection :: run (
884
+ conn,
885
+ Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
886
+ TestBuilder :: default ( ) ,
887
+ )
888
+ . unwrap ( ) ;
889
+ assert_eq ! ( state, State :: Txn ) ;
890
+ assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
891
+ before. elapsed ( )
892
+ }
893
+ } ) ;
898
894
899
- let ( builder, state) = Connection :: run ( conn. inner . clone ( ) , Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) , TestBuilder :: default ( ) ) . unwrap ( ) ;
900
- assert_eq ! ( state, State :: Init ) ;
901
- assert ! ( matches!( builder. into_ret( ) [ 0 ] , Err ( Error :: LibSqlTxTimeout ) ) ) ;
895
+ let wait_time = TXN_TIMEOUT / 10 ;
896
+ tokio:: time:: sleep ( wait_time) . await ;
897
+
898
+ tokio:: task:: spawn_blocking ( {
899
+ let conn = conn1. inner . clone ( ) ;
900
+ move || {
901
+ let ( builder, state) =
902
+ Connection :: run ( conn, Program :: seq ( & [ "COMMIT" ] ) , TestBuilder :: default ( ) )
903
+ . unwrap ( ) ;
904
+ assert_eq ! ( state, State :: Init ) ;
905
+ assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
906
+ }
907
+ } )
908
+ . await
909
+ . unwrap ( ) ;
910
+
911
+ let elapsed = handle. await . unwrap ( ) ;
912
+
913
+ let epsilon = Duration :: from_millis ( 100 ) ;
914
+ assert ! ( ( wait_time..wait_time + epsilon) . contains( & elapsed) ) ;
902
915
}
903
916
}
0 commit comments