11use std:: collections:: hash_map:: Entry ;
22use std:: collections:: HashMap ;
3+ use std:: future:: poll_fn;
34use std:: mem:: size_of;
45use std:: ops:: Deref ;
56use std:: path:: PathBuf ;
7+ use std:: pin:: Pin ;
68use std:: sync:: Arc ;
9+ use std:: task:: { Context , Poll } ;
710use std:: time:: { Duration , Instant } ;
811
912use bytes:: Bytes ;
1013use either:: Either ;
14+ use futures:: Future ;
1115use libsqlx:: libsql:: { LibsqlDatabase , LogCompactor , LogFile , PrimaryType , ReplicaType } ;
1216use libsqlx:: program:: Program ;
1317use libsqlx:: proxy:: { WriteProxyConnection , WriteProxyDatabase } ;
@@ -19,12 +23,12 @@ use libsqlx::{
1923use parking_lot:: Mutex ;
2024use tokio:: sync:: { mpsc, oneshot} ;
2125use tokio:: task:: { block_in_place, JoinSet } ;
22- use tokio:: time:: timeout;
26+ use tokio:: time:: { timeout, Sleep } ;
2327
2428use crate :: hrana;
2529use crate :: hrana:: http:: handle_pipeline;
2630use crate :: hrana:: http:: proto:: { PipelineRequestBody , PipelineResponseBody } ;
27- use crate :: linc:: bus:: { Dispatch } ;
31+ use crate :: linc:: bus:: Dispatch ;
2832use crate :: linc:: proto:: {
2933 BuilderStep , Enveloppe , Frames , Message , ProxyResponse , StepError , Value ,
3034} ;
@@ -50,7 +54,9 @@ pub enum AllocationMessage {
5054 Inbound ( Inbound ) ,
5155}
5256
53- pub struct RemoteDb ;
57+ pub struct RemoteDb {
58+ proxy_request_timeout_duration : Duration ,
59+ }
5460
5561#[ derive( Clone ) ]
5662pub struct RemoteConn {
@@ -62,10 +68,12 @@ struct Request {
6268 builder : Box < dyn ResultBuilder > ,
6369 pgm : Option < Program > ,
6470 next_seq_no : u32 ,
71+ timeout : Pin < Box < Sleep > > ,
6572}
6673
6774pub struct RemoteConnInner {
6875 current_req : Mutex < Option < Request > > ,
76+ request_timeout_duration : Duration ,
6977}
7078
7179impl Deref for RemoteConn {
@@ -93,6 +101,7 @@ impl libsqlx::Connection for RemoteConn {
93101 builder,
94102 pgm : Some ( program. clone ( ) ) ,
95103 next_seq_no : 0 ,
104+ timeout : Box :: pin ( tokio:: time:: sleep ( self . inner . request_timeout_duration ) ) ,
96105 } ) ,
97106 } ;
98107
@@ -111,6 +120,7 @@ impl libsqlx::Database for RemoteDb {
111120 Ok ( RemoteConn {
112121 inner : Arc :: new ( RemoteConnInner {
113122 current_req : Default :: default ( ) ,
123+ request_timeout_duration : self . proxy_request_timeout_duration ,
114124 } ) ,
115125 } )
116126 }
@@ -462,9 +472,14 @@ impl Database {
462472 frame_notifier : receiver,
463473 } )
464474 }
465- DbConfig :: Replica { primary_node_id } => {
475+ DbConfig :: Replica {
476+ primary_node_id,
477+ proxy_request_timeout_duration,
478+ } => {
466479 let rdb = LibsqlDatabase :: new_replica ( path, MAX_INJECTOR_BUFFER_CAP , ( ) ) . unwrap ( ) ;
467- let wdb = RemoteDb ;
480+ let wdb = RemoteDb {
481+ proxy_request_timeout_duration,
482+ } ;
468483 let mut db = WriteProxyDatabase :: new ( rdb, wdb, Arc :: new ( |_| ( ) ) ) ;
469484 let injector = db. injector ( ) . unwrap ( ) ;
470485 let ( sender, receiver) = mpsc:: channel ( 16 ) ;
@@ -502,7 +517,7 @@ impl Database {
502517 conn : db. connect ( ) . unwrap ( ) ,
503518 connection_id,
504519 next_req_id : 0 ,
505- primary_id : * primary_id,
520+ primary_node_id : * primary_id,
506521 database_id : DatabaseId :: from_name ( & alloc. db_name ) ,
507522 dispatcher : alloc. dispatcher . clone ( ) ,
508523 } ) ,
@@ -520,8 +535,8 @@ struct PrimaryConnection {
520535
521536#[ async_trait:: async_trait]
522537impl ConnectionHandler for PrimaryConnection {
523- fn exec_ready ( & self ) -> bool {
524- true
538+ fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < ( ) > {
539+ Poll :: Ready ( ( ) )
525540 }
526541
527542 async fn handle_exec ( & mut self , exec : ExecFn ) {
@@ -537,7 +552,7 @@ struct ReplicaConnection {
537552 conn : ProxyConnection ,
538553 connection_id : u32 ,
539554 next_req_id : u32 ,
540- primary_id : NodeId ,
555+ primary_node_id : NodeId ,
541556 database_id : DatabaseId ,
542557 dispatcher : Arc < dyn Dispatch > ,
543558}
@@ -551,16 +566,21 @@ impl ReplicaConnection {
551566 // TODO: pass actual config
552567 let config = QueryBuilderConfig { max_size : None } ;
553568 let mut finnalized = false ;
554- for step in resp. row_steps . iter ( ) {
555- if finnalized { break } ;
569+ for step in resp. row_steps . into_iter ( ) {
570+ if finnalized {
571+ break ;
572+ } ;
556573 match step {
557574 BuilderStep :: Init => req. builder . init ( & config) . unwrap ( ) ,
558575 BuilderStep :: BeginStep => req. builder . begin_step ( ) . unwrap ( ) ,
559576 BuilderStep :: FinishStep ( affected_row_count, last_insert_rowid) => req
560577 . builder
561- . finish_step ( * affected_row_count, * last_insert_rowid)
578+ . finish_step ( affected_row_count, last_insert_rowid)
579+ . unwrap ( ) ,
580+ BuilderStep :: StepError ( e) => req
581+ . builder
582+ . step_error ( todo ! ( "handle proxy step error" ) )
562583 . unwrap ( ) ,
563- BuilderStep :: StepError ( e) => req. builder . step_error ( todo ! ( ) ) . unwrap ( ) ,
564584 BuilderStep :: ColsDesc ( cols) => req
565585 . builder
566586 . cols_description ( & mut cols. iter ( ) . map ( |c| Column {
@@ -570,11 +590,15 @@ impl ReplicaConnection {
570590 . unwrap ( ) ,
571591 BuilderStep :: BeginRows => req. builder . begin_rows ( ) . unwrap ( ) ,
572592 BuilderStep :: BeginRow => req. builder . begin_row ( ) . unwrap ( ) ,
573- BuilderStep :: AddRowValue ( v) => req. builder . add_row_value ( v . into ( ) ) . unwrap ( ) ,
593+ BuilderStep :: AddRowValue ( v) => req. builder . add_row_value ( ( & v ) . into ( ) ) . unwrap ( ) ,
574594 BuilderStep :: FinishRow => req. builder . finish_row ( ) . unwrap ( ) ,
575595 BuilderStep :: FinishRows => req. builder . finish_rows ( ) . unwrap ( ) ,
576596 BuilderStep :: Finnalize { is_txn, frame_no } => {
577- let _ = req. builder . finnalize ( * is_txn, * frame_no) . unwrap ( ) ;
597+ let _ = req. builder . finnalize ( is_txn, frame_no) . unwrap ( ) ;
598+ finnalized = true ;
599+ } ,
600+ BuilderStep :: FinnalizeError ( e) => {
601+ req. builder . finnalize_error ( e) ;
578602 finnalized = true ;
579603 }
580604 }
@@ -596,9 +620,28 @@ impl ReplicaConnection {
596620
597621#[ async_trait:: async_trait]
598622impl ConnectionHandler for ReplicaConnection {
599- fn exec_ready ( & self ) -> bool {
623+ fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < ( ) > {
600624 // we are currently handling a request on this connection
601- self . conn . writer ( ) . current_req . lock ( ) . is_none ( )
625+ // self.conn.writer().current_req.timeout.poll()
626+ let mut req = self . conn . writer ( ) . current_req . lock ( ) ;
627+ let should_abort_query = match & mut * req {
628+ Some ( ref mut req) => {
629+ match req. timeout . as_mut ( ) . poll ( cx) {
630+ Poll :: Ready ( _) => {
631+ req. builder . finnalize_error ( "request timed out" . to_string ( ) ) ;
632+ true
633+ }
634+ Poll :: Pending => return Poll :: Pending ,
635+ }
636+ }
637+ None => return Poll :: Ready ( ( ) ) ,
638+ } ;
639+
640+ if should_abort_query {
641+ * req = None
642+ }
643+
644+ Poll :: Ready ( ( ) )
602645 }
603646
604647 async fn handle_exec ( & mut self , exec : ExecFn ) {
@@ -616,7 +659,7 @@ impl ConnectionHandler for ReplicaConnection {
616659 req. id = Some ( req_id) ;
617660
618661 let msg = Outbound {
619- to : self . primary_id ,
662+ to : self . primary_node_id ,
620663 enveloppe : Enveloppe {
621664 database_id : Some ( self . database_id ) ,
622665 message : Message :: ProxyRequest {
@@ -654,10 +697,10 @@ where
654697 L : ConnectionHandler ,
655698 R : ConnectionHandler ,
656699{
657- fn exec_ready ( & self ) -> bool {
700+ fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < ( ) > {
658701 match self {
659- Either :: Left ( l) => l. exec_ready ( ) ,
660- Either :: Right ( r) => r. exec_ready ( ) ,
702+ Either :: Left ( l) => l. poll_ready ( cx ) ,
703+ Either :: Right ( r) => r. poll_ready ( cx ) ,
661704 }
662705 }
663706
@@ -852,7 +895,7 @@ impl Allocation {
852895 } ;
853896 conn. execute_program ( & program, Box :: new ( builder) ) . unwrap ( ) ;
854897 } )
855- . await ;
898+ . await ;
856899 } ;
857900
858901 if self . database . is_primary ( ) {
@@ -921,19 +964,21 @@ struct Connection<C> {
921964
922965#[ async_trait:: async_trait]
923966trait ConnectionHandler : Send {
924- fn exec_ready ( & self ) -> bool ;
967+ fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < ( ) > ;
925968 async fn handle_exec ( & mut self , exec : ExecFn ) ;
926969 async fn handle_inbound ( & mut self , msg : Inbound ) ;
927970}
928971
929972impl < C : ConnectionHandler > Connection < C > {
930973 async fn run ( mut self ) -> ( NodeId , u32 ) {
931974 loop {
975+ let fut =
976+ futures:: future:: join ( self . exec . recv ( ) , poll_fn ( |cx| self . conn . poll_ready ( cx) ) ) ;
932977 tokio:: select! {
933978 Some ( inbound) = self . inbound. recv( ) => {
934979 self . conn. handle_inbound( inbound) . await ;
935980 }
936- Some ( exec) = self . exec . recv ( ) , if self . conn . exec_ready ( ) => {
981+ ( Some ( exec) , _ ) = fut => {
937982 self . conn. handle_exec( exec) . await ;
938983 } ,
939984 else => break ,
@@ -943,3 +988,65 @@ impl<C: ConnectionHandler> Connection<C> {
943988 self . id
944989 }
945990}
991+
992+ #[ cfg( test) ]
993+ mod test {
994+ use tokio:: sync:: Notify ;
995+
996+ use crate :: linc:: bus:: Bus ;
997+
998+ use super :: * ;
999+
1000+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 1 ) ]
1001+ async fn proxy_request_timeout ( ) {
1002+ let bus = Arc :: new ( Bus :: new ( 0 , |_, _| async { } ) ) ;
1003+ let _queue = bus. connect ( 1 ) ; // pretend connection to node 1
1004+ let tmp = tempfile:: TempDir :: new ( ) . unwrap ( ) ;
1005+ let read_db = LibsqlDatabase :: new_replica ( tmp. path ( ) . to_path_buf ( ) , 1 , ( ) ) . unwrap ( ) ;
1006+ let write_db = RemoteDb {
1007+ proxy_request_timeout_duration : Duration :: from_millis ( 100 ) ,
1008+ } ;
1009+ let db = WriteProxyDatabase :: new ( read_db, write_db, Arc :: new ( |_| ( ) ) ) ;
1010+ let conn = db. connect ( ) . unwrap ( ) ;
1011+ let conn = ReplicaConnection {
1012+ conn,
1013+ connection_id : 0 ,
1014+ next_req_id : 0 ,
1015+ primary_node_id : 1 ,
1016+ database_id : DatabaseId :: random ( ) ,
1017+ dispatcher : bus,
1018+ } ;
1019+
1020+ let ( exec_sender, exec) = mpsc:: channel ( 1 ) ;
1021+ let ( _inbound_sender, inbound) = mpsc:: channel ( 1 ) ;
1022+ let connection = Connection {
1023+ id : ( 0 , 0 ) ,
1024+ conn,
1025+ exec,
1026+ inbound,
1027+ } ;
1028+
1029+ let handle = tokio:: spawn ( connection. run ( ) ) ;
1030+
1031+ let notify = Arc :: new ( Notify :: new ( ) ) ;
1032+ struct Builder ( Arc < Notify > ) ;
1033+ impl ResultBuilder for Builder {
1034+ fn finnalize_error ( & mut self , _e : String ) {
1035+ self . 0 . notify_waiters ( )
1036+ }
1037+ }
1038+
1039+ let builder = Box :: new ( Builder ( notify. clone ( ) ) ) ;
1040+ exec_sender
1041+ . send ( Box :: new ( move |conn| {
1042+ conn. execute_program ( & Program :: seq ( & [ "create table test (c)" ] ) , builder)
1043+ . unwrap ( ) ;
1044+ } ) )
1045+ . await
1046+ . unwrap ( ) ;
1047+
1048+ notify. notified ( ) . await ;
1049+
1050+ handle. abort ( ) ;
1051+ }
1052+ }
0 commit comments