1
- use std:: ops:: DerefMut ;
2
1
use std:: {
3
2
collections:: HashMap ,
4
3
io,
@@ -142,6 +141,7 @@ pub struct ClientState {
142
141
max_frame_size : u32 ,
143
142
last_heatbeat : Instant ,
144
143
heartbeat_task : Option < task:: TaskHandle > ,
144
+ last_received_message : Arc < RwLock < Instant > > ,
145
145
}
146
146
147
147
/// Raw API for taking to RabbitMQ stream
@@ -165,8 +165,9 @@ impl Client {
165
165
166
166
let ( sender, receiver) = Client :: create_connection ( & broker) . await ?;
167
167
168
- let dispatcher = Dispatcher :: new ( ) ;
168
+ let last_received_message = Arc :: new ( RwLock :: new ( Instant :: now ( ) ) ) ;
169
169
170
+ let dispatcher = Dispatcher :: new ( ) ;
170
171
let state = ClientState {
171
172
server_properties : HashMap :: new ( ) ,
172
173
connection_properties : HashMap :: new ( ) ,
@@ -175,6 +176,7 @@ impl Client {
175
176
max_frame_size : broker. max_frame_size ,
176
177
last_heatbeat : Instant :: now ( ) ,
177
178
heartbeat_task : None ,
179
+ last_received_message : last_received_message. clone ( ) ,
178
180
} ;
179
181
let mut client = Client {
180
182
dispatcher,
@@ -483,6 +485,14 @@ impl Client {
483
485
self . filtering_supported
484
486
}
485
487
488
+ pub async fn set_heartbeat ( & self , heartbeat : u32 ) {
489
+ let mut state = self . state . write ( ) . await ;
490
+ state. heartbeat = heartbeat;
491
+ // Eventually, this drops the previous heartbeat task
492
+ state. heartbeat_task =
493
+ self . start_hearbeat_task ( heartbeat, state. last_received_message . clone ( ) ) ;
494
+ }
495
+
486
496
async fn create_connection (
487
497
broker : & ClientOptions ,
488
498
) -> Result <
@@ -500,6 +510,7 @@ impl Client {
500
510
501
511
Ok ( ( tx, rx) )
502
512
}
513
+
503
514
async fn initialize < T > ( & mut self , receiver : ChannelReceiver < T > ) -> Result < ( ) , ClientError >
504
515
where
505
516
T : Stream < Item = Result < Response , ClientError > > + Unpin + Send ,
@@ -523,7 +534,8 @@ impl Client {
523
534
524
535
// Start heartbeat task after connection is established
525
536
let mut state = self . state . write ( ) . await ;
526
- state. heartbeat_task = self . start_hearbeat_task ( state. heartbeat ) ;
537
+ state. heartbeat_task =
538
+ self . start_hearbeat_task ( state. heartbeat , state. last_received_message . clone ( ) ) ;
527
539
drop ( state) ;
528
540
529
541
Ok ( ( ) )
@@ -664,7 +676,8 @@ impl Client {
664
676
665
677
if state. heartbeat_task . take ( ) . is_some ( ) {
666
678
// Start heartbeat task after connection is established
667
- state. heartbeat_task = self . start_hearbeat_task ( state. heartbeat ) ;
679
+ state. heartbeat_task =
680
+ self . start_hearbeat_task ( state. heartbeat , state. last_received_message . clone ( ) ) ;
668
681
}
669
682
670
683
drop ( state) ;
@@ -677,14 +690,22 @@ impl Client {
677
690
self . tune_notifier . notify_one ( ) ;
678
691
}
679
692
680
- fn start_hearbeat_task ( & self , heartbeat : u32 ) -> Option < task:: TaskHandle > {
693
+ fn start_hearbeat_task (
694
+ & self ,
695
+ heartbeat : u32 ,
696
+ last_received_message : Arc < RwLock < Instant > > ,
697
+ ) -> Option < task:: TaskHandle > {
681
698
if heartbeat == 0 {
682
699
return None ;
683
700
}
684
701
let heartbeat_interval = ( heartbeat / 2 ) . max ( 1 ) ;
685
702
let channel = self . channel . clone ( ) ;
686
703
704
+ let client = self . clone ( ) ;
705
+
687
706
let heartbeat_task: task:: TaskHandle = tokio:: spawn ( async move {
707
+ let timeout_threashold = u64:: from ( heartbeat * 4 ) ;
708
+
688
709
loop {
689
710
trace ! ( "Sending heartbeat" ) ;
690
711
if channel
@@ -695,7 +716,20 @@ impl Client {
695
716
break ;
696
717
}
697
718
tokio:: time:: sleep ( Duration :: from_secs ( heartbeat_interval. into ( ) ) ) . await ;
719
+
720
+ let now = Instant :: now ( ) ;
721
+ let last_message = last_received_message. read ( ) . await ;
722
+ if now. duration_since ( * last_message) >= Duration :: from_secs ( timeout_threashold) {
723
+ warn ! ( "Heartbeat timeout reached. Force closing connection." ) ;
724
+ if !client. is_closed ( ) {
725
+ if let Err ( e) = client. close ( ) . await {
726
+ warn ! ( "Error closing client: {}" , e) ;
727
+ }
728
+ }
729
+ break ;
730
+ }
698
731
}
732
+
699
733
warn ! ( "Heartbeat task stopped. Force closing connection" ) ;
700
734
} )
701
735
. into ( ) ;
@@ -725,17 +759,28 @@ impl Client {
725
759
impl MessageHandler for Client {
726
760
async fn handle_message ( & self , item : MessageResult ) -> RabbitMQStreamResult < ( ) > {
727
761
match & item {
728
- Some ( Ok ( response) ) => match response. kind_ref ( ) {
729
- ResponseKind :: Tunes ( tune) => self . handle_tune_command ( tune) . await ,
730
- ResponseKind :: Heartbeat ( _) => self . handle_heart_beat_command ( ) . await ,
731
- _ => {
732
- if let Some ( handler) = self . state . read ( ) . await . handler . as_ref ( ) {
733
- let handler = handler. clone ( ) ;
734
-
735
- tokio:: task:: spawn ( async move { handler. handle_message ( item) . await } ) ;
762
+ Some ( Ok ( response) ) => {
763
+ // Update last received message time: needed for heartbeat task
764
+ {
765
+ let s = self . state . read ( ) . await ;
766
+ let mut last_received_message = s. last_received_message . write ( ) . await ;
767
+ * last_received_message = Instant :: now ( ) ;
768
+ drop ( last_received_message) ;
769
+ drop ( s) ;
770
+ }
771
+
772
+ match response. kind_ref ( ) {
773
+ ResponseKind :: Tunes ( tune) => self . handle_tune_command ( tune) . await ,
774
+ ResponseKind :: Heartbeat ( _) => self . handle_heart_beat_command ( ) . await ,
775
+ _ => {
776
+ if let Some ( handler) = self . state . read ( ) . await . handler . as_ref ( ) {
777
+ let handler = handler. clone ( ) ;
778
+
779
+ tokio:: task:: spawn ( async move { handler. handle_message ( item) . await } ) ;
780
+ }
736
781
}
737
782
}
738
- } ,
783
+ }
739
784
Some ( Err ( err) ) => {
740
785
trace ! ( ?err) ;
741
786
if let Some ( handler) = self . state . read ( ) . await . handler . as_ref ( ) {
0 commit comments