Skip to content

Commit af13e7c

Browse files
committed
Implement heartbeat timeout
1 parent 7eef644 commit af13e7c

File tree

4 files changed

+71
-24
lines changed

4 files changed

+71
-24
lines changed

src/client/mod.rs

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::ops::DerefMut;
21
use std::{
32
collections::HashMap,
43
io,
@@ -142,6 +141,7 @@ pub struct ClientState {
142141
max_frame_size: u32,
143142
last_heatbeat: Instant,
144143
heartbeat_task: Option<task::TaskHandle>,
144+
last_received_message: Arc<RwLock<Instant>>,
145145
}
146146

147147
/// Raw API for talking to RabbitMQ stream
@@ -165,8 +165,9 @@ impl Client {
165165

166166
let (sender, receiver) = Client::create_connection(&broker).await?;
167167

168-
let dispatcher = Dispatcher::new();
168+
let last_received_message = Arc::new(RwLock::new(Instant::now()));
169169

170+
let dispatcher = Dispatcher::new();
170171
let state = ClientState {
171172
server_properties: HashMap::new(),
172173
connection_properties: HashMap::new(),
@@ -175,6 +176,7 @@ impl Client {
175176
max_frame_size: broker.max_frame_size,
176177
last_heatbeat: Instant::now(),
177178
heartbeat_task: None,
179+
last_received_message: last_received_message.clone(),
178180
};
179181
let mut client = Client {
180182
dispatcher,
@@ -483,6 +485,14 @@ impl Client {
483485
self.filtering_supported
484486
}
485487

488+
/// Updates the heartbeat value (in seconds) stored in the client state and
/// replaces the stored heartbeat task with a newly started one.
///
/// The new task is created by `start_hearbeat_task`, which receives the
/// shared `last_received_message` timestamp held in the state. That helper
/// returns `None` when `heartbeat` is 0, so in that case the stored task
/// handle becomes `None`.
pub async fn set_heartbeat(&self, heartbeat: u32) {
489+
// The state write lock is held while both the value and the task handle are updated.
let mut state = self.state.write().await;
490+
state.heartbeat = heartbeat;
491+
// Eventually, this drops the previous heartbeat task
492+
state.heartbeat_task =
493+
self.start_hearbeat_task(heartbeat, state.last_received_message.clone());
494+
}
495+
486496
async fn create_connection(
487497
broker: &ClientOptions,
488498
) -> Result<
@@ -500,6 +510,7 @@ impl Client {
500510

501511
Ok((tx, rx))
502512
}
513+
503514
async fn initialize<T>(&mut self, receiver: ChannelReceiver<T>) -> Result<(), ClientError>
504515
where
505516
T: Stream<Item = Result<Response, ClientError>> + Unpin + Send,
@@ -523,7 +534,8 @@ impl Client {
523534

524535
// Start heartbeat task after connection is established
525536
let mut state = self.state.write().await;
526-
state.heartbeat_task = self.start_hearbeat_task(state.heartbeat);
537+
state.heartbeat_task =
538+
self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
527539
drop(state);
528540

529541
Ok(())
@@ -664,7 +676,8 @@ impl Client {
664676

665677
if state.heartbeat_task.take().is_some() {
666678
// Start heartbeat task after connection is established
667-
state.heartbeat_task = self.start_hearbeat_task(state.heartbeat);
679+
state.heartbeat_task =
680+
self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone());
668681
}
669682

670683
drop(state);
@@ -677,14 +690,22 @@ impl Client {
677690
self.tune_notifier.notify_one();
678691
}
679692

680-
fn start_hearbeat_task(&self, heartbeat: u32) -> Option<task::TaskHandle> {
693+
fn start_hearbeat_task(
694+
&self,
695+
heartbeat: u32,
696+
last_received_message: Arc<RwLock<Instant>>,
697+
) -> Option<task::TaskHandle> {
681698
if heartbeat == 0 {
682699
return None;
683700
}
684701
let heartbeat_interval = (heartbeat / 2).max(1);
685702
let channel = self.channel.clone();
686703

704+
let client = self.clone();
705+
687706
let heartbeat_task: task::TaskHandle = tokio::spawn(async move {
707+
let timeout_threashold = u64::from(heartbeat * 4);
708+
688709
loop {
689710
trace!("Sending heartbeat");
690711
if channel
@@ -695,7 +716,20 @@ impl Client {
695716
break;
696717
}
697718
tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await;
719+
720+
let now = Instant::now();
721+
let last_message = last_received_message.read().await;
722+
if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) {
723+
warn!("Heartbeat timeout reached. Force closing connection.");
724+
if !client.is_closed() {
725+
if let Err(e) = client.close().await {
726+
warn!("Error closing client: {}", e);
727+
}
728+
}
729+
break;
730+
}
698731
}
732+
699733
warn!("Heartbeat task stopped. Force closing connection");
700734
})
701735
.into();
@@ -725,17 +759,28 @@ impl Client {
725759
impl MessageHandler for Client {
726760
async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
727761
match &item {
728-
Some(Ok(response)) => match response.kind_ref() {
729-
ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
730-
ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
731-
_ => {
732-
if let Some(handler) = self.state.read().await.handler.as_ref() {
733-
let handler = handler.clone();
734-
735-
tokio::task::spawn(async move { handler.handle_message(item).await });
762+
Some(Ok(response)) => {
763+
// Update last received message time: needed for heartbeat task
764+
{
765+
let s = self.state.read().await;
766+
let mut last_received_message = s.last_received_message.write().await;
767+
*last_received_message = Instant::now();
768+
drop(last_received_message);
769+
drop(s);
770+
}
771+
772+
match response.kind_ref() {
773+
ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await,
774+
ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await,
775+
_ => {
776+
if let Some(handler) = self.state.read().await.handler.as_ref() {
777+
let handler = handler.clone();
778+
779+
tokio::task::spawn(async move { handler.handle_message(item).await });
780+
}
736781
}
737782
}
738-
},
783+
}
739784
Some(Err(err)) => {
740785
trace!(?err);
741786
if let Some(handler) = self.state.read().await.handler.as_ref() {

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pub use crate::consumer::{
9393
Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration, MessageContext,
9494
};
9595
pub use crate::environment::{Environment, EnvironmentBuilder};
96-
pub use crate::producer::{Dedup, NoDedup, OnClosed, Producer, ProducerBuilder};
96+
pub use crate::producer::{ConfirmationStatus, Dedup, NoDedup, OnClosed, Producer, ProducerBuilder};
9797
pub mod types {
9898

9999
pub use crate::byte_capacity::ByteCapacity;

src/producer.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ impl<T> ProducerBuilder<T> {
138138
.create_producer_client(stream, self.client_provided_name.clone())
139139
.await?;
140140

141+
if let Some(heartbeat) = self.overwrite_heartbeat {
142+
client.set_heartbeat(heartbeat).await;
143+
}
144+
141145
let mut publish_version = 1;
142146

143147
if self.filter_value_extractor.is_some() {
@@ -217,10 +221,7 @@ impl<T> ProducerBuilder<T> {
217221
self
218222
}
219223

220-
pub fn overwrite_heartbeat(
221-
mut self,
222-
heartbeat: u32,
223-
) -> ProducerBuilder<T> {
224+
/// Records a heartbeat value to apply to the producer's client.
///
/// When a value has been set, the build step passes it to
/// `client.set_heartbeat` right after the producer client is created.
/// Consumes the builder and returns it so calls can be chained.
pub fn overwrite_heartbeat(mut self, heartbeat: u32) -> ProducerBuilder<T> {
224225
self.overwrite_heartbeat = Some(heartbeat);
225226
self
226227
}

tests/producer_test.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc, time::Duration};
33
use chrono::Utc;
44
use fake::{Fake, Faker};
55
use futures::{lock::Mutex, StreamExt};
6-
use tokio::{sync::mpsc::channel, task::yield_now, time::sleep};
6+
use tokio::{sync::mpsc::channel, time::sleep};
77

88
use rabbitmq_stream_client::{
99
error::ClientError,
@@ -19,7 +19,6 @@ use common::*;
1919
use rabbitmq_stream_client::types::{
2020
HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy,
2121
};
22-
use tracing::span;
2322

2423
use std::sync::atomic::{AtomicU32, Ordering};
2524
use tokio::sync::Notify;
@@ -860,8 +859,10 @@ async fn producer_timeout() {
860859

861860
sleep(Duration::from_millis(500)).await;
862861

863-
let connection = wait_for_named_connection(client_provided_name.clone()).await;
864-
drop_connection(connection).await;
862+
let is_stopped = tokio::select! {
863+
_ = notifier.notified() => true,
864+
_ = sleep(Duration::from_secs(5)) => false,
865+
};
865866

866-
notifier.notified().await;
867+
assert!(is_stopped, "Producer did not stop after timeout");
867868
}

0 commit comments

Comments
 (0)