From ed072711419a9b9b6089c16154c48169f979c346 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Thu, 29 May 2025 12:51:52 +0200 Subject: [PATCH 1/3] Handle on_closed callback --- src/client/dispatcher.rs | 32 +++++++++--------- src/client/mod.rs | 72 ++++++++++++++++++++-------------------- src/environment.rs | 1 + src/lib.rs | 2 +- src/producer.rs | 17 ++++++++++ tests/producer_test.rs | 42 ++++++++++++++++++++++- 6 files changed, 112 insertions(+), 54 deletions(-) diff --git a/src/client/dispatcher.rs b/src/client/dispatcher.rs index 1abcf43b..1f9fa6d5 100644 --- a/src/client/dispatcher.rs +++ b/src/client/dispatcher.rs @@ -19,22 +19,6 @@ use super::{channel::ChannelReceiver, handler::MessageHandler}; #[derive(Clone)] pub(crate) struct Dispatcher(DispatcherState); -pub(crate) struct DispatcherState { - requests: Arc, - correlation_id: Arc, - handler: Arc>>, -} - -impl Clone for DispatcherState { - fn clone(&self) -> Self { - DispatcherState { - requests: self.requests.clone(), - correlation_id: self.correlation_id.clone(), - handler: self.handler.clone(), - } - } -} - struct RequestsMap { requests: DashMap>, closed: AtomicBool, @@ -126,6 +110,22 @@ where } } +pub(crate) struct DispatcherState { + requests: Arc, + correlation_id: Arc, + handler: Arc>>, +} + +impl Clone for DispatcherState { + fn clone(&self) -> Self { + DispatcherState { + requests: self.requests.clone(), + correlation_id: self.correlation_id.clone(), + handler: self.handler.clone(), + } + } +} + impl DispatcherState where T: MessageHandler, diff --git a/src/client/mod.rs b/src/client/mod.rs index 15c42242..f7faba58 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -144,42 +144,6 @@ pub struct ClientState { heartbeat_task: Option, } -#[async_trait::async_trait] -impl MessageHandler for Client { - async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> { - match &item { - Some(Ok(response)) => match response.kind_ref() { - ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, - ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, - _ => { - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - - tokio::task::spawn(async move { handler.handle_message(item).await }); - } - } - }, - Some(Err(err)) => { - trace!(?err); - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - - tokio::task::spawn(async move { handler.handle_message(item).await }); - } - } - None => { - trace!("Closing client"); - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - tokio::task::spawn(async move { handler.handle_message(None).await }); - } - } - } - - Ok(()) - } -} - /// Raw API for taking to RabbitMQ stream /// /// For high level APIs check [`crate::Environment`] @@ -751,3 +715,39 @@ impl Client { .await } } + +#[async_trait::async_trait] +impl MessageHandler for Client { + async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> { + match &item { + Some(Ok(response)) => match response.kind_ref() { + ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, + ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, + _ => { + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + + tokio::task::spawn(async move { handler.handle_message(item).await }); + } + } + }, + Some(Err(err)) => { + trace!(?err); + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + + tokio::task::spawn(async move { handler.handle_message(item).await }); + } + } + None => { + trace!("Closing client"); + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + tokio::task::spawn(async move { handler.handle_message(None).await }); + } + } + } + + Ok(()) + } +} diff --git a/src/environment.rs b/src/environment.rs index 9592dbcd..ba8682c8 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -199,6 +199,7 @@ impl Environment { data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), + on_closed: None, } } diff --git a/src/lib.rs b/src/lib.rs index d6633b00..f0fe4017 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ pub use crate::consumer::{ Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration, MessageContext, }; pub use crate::environment::{Environment, EnvironmentBuilder}; -pub use crate::producer::{Dedup, NoDedup, Producer, ProducerBuilder}; +pub use crate::producer::{Dedup, NoDedup, OnClosed, Producer, ProducerBuilder}; pub mod types { pub use crate::byte_capacity::ByteCapacity; diff --git a/src/producer.rs b/src/producer.rs index 9fa3b93e..5d604f8a 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -116,6 +116,7 @@ pub struct ProducerBuilder { pub(crate) data: PhantomData, pub filter_value_extractor: Option, pub(crate) client_provided_name: String, + pub(crate) on_closed: Option>, } #[derive(Clone)] @@ -151,6 +152,7 @@ impl ProducerBuilder { let confirm_handler = ProducerConfirmHandler { waiting_confirmations: waiting_confirmations.clone(), metrics_collector, + on_closed: self.on_closed, }; client.set_handler(confirm_handler).await; @@ -204,6 +206,11 @@ impl ProducerBuilder { } } + pub fn on_closed(mut self, on_closed: Arc) -> ProducerBuilder { + self.on_closed = Some(on_closed); + self + } + pub fn batch_size(mut self, batch_size: usize) -> Self { self.batch_size = batch_size; self @@ -223,6 +230,7 @@ impl ProducerBuilder { data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), + on_closed: self.on_closed, } } @@ -505,9 +513,15 @@ impl Producer { } } +#[async_trait::async_trait] +pub trait OnClosed { + async fn on_closed(&self); +} + struct ProducerConfirmHandler { waiting_confirmations: WaiterMap, metrics_collector: Arc, + on_closed: Option>, } #[async_trait::async_trait] @@ -583,6 +597,9 @@ impl MessageHandler for ProducerConfirmHandler { } None => { trace!("Connection closed"); + if let Some(on_close) = &self.on_closed { + on_close.on_closed().await; + } // TODO connection close clean all waiting } } diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 0e9c1858..0b01a15d 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -8,7 +8,7 @@ use tokio::{sync::mpsc::channel, task::yield_now, time::sleep}; use rabbitmq_stream_client::{ error::ClientError, types::{Message, OffsetSpecification, SimpleValue}, - Environment, + Environment, OnClosed, }; #[path = "./common.rs"] @@ -784,3 +784,43 @@ async fn producer_drop() { let metrics = tokio::runtime::Handle::current().metrics(); assert_eq!(metrics.num_alive_tasks(), 0); } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_drop_connection_on_close() { + struct Foo { + notifier: Arc, + } + #[async_trait::async_trait] + impl OnClosed for Foo { + async fn on_closed(&self) { + self.notifier.notify_one(); + } + } + + let notifier = Arc::new(Notify::new()); + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .on_closed(Arc::new(Foo { + notifier: notifier.clone(), + })) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + sleep(Duration::from_millis(500)).await; + + let connection = wait_for_named_connection(client_provided_name.clone()).await; + drop_connection(connection).await; + + notifier.notified().await; +} From 7eef6446e6c31c636bc74f587d52972587c50e90 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Sat, 31 May 2025 10:48:04 +0200 Subject: [PATCH 2/3] Little improvement --- examples/ha_producer.rs | 92 +++++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 21 ++++++---- src/environment.rs | 1 + src/producer.rs | 10 +++++ tests/producer_test.rs | 41 ++++++++++++++++++ 5 files changed, 157 insertions(+), 8 deletions(-) create mode 100644 examples/ha_producer.rs diff --git a/examples/ha_producer.rs b/examples/ha_producer.rs new file mode 100644 index 00000000..a234246b --- /dev/null +++ b/examples/ha_producer.rs @@ -0,0 +1,92 @@ +use rabbitmq_stream_client::error::{StreamCreateError, ProducerPublishError}; +use rabbitmq_stream_client::types::{ByteCapacity, Message, ResponseCode}; +use rabbitmq_stream_client::{ConfirmationStatus, NoDedup, OnClosed, Producer, RabbitMQStreamResult}; +use tokio::sync::RwLock; +use rabbitmq_stream_client::Environment; + +struct MyHAProducer { + environment: Environment, + stream: String, + producer: RwLock> +} + +#[async_trait::async_trait] +impl OnClosed for MyHAProducer { + async fn on_closed(&self) { + let mut producer = self.producer.write().await; + + let new_producer = self.environment + .producer() + .build(&self.stream).await + .unwrap(); + + *producer = new_producer; + } +} + +impl MyHAProducer { + async fn new(environment: Environment, stream: &str) -> RabbitMQStreamResult { + ensure_stream_exists(&environment, stream).await?; + + let producer = environment + .producer() + .build(stream).await + .unwrap(); + + Ok(Self { + environment, + stream: stream.to_string(), + producer: RwLock::new(producer), + }) + } + + async fn send_with_confirm(&self, message: Message) -> Result { + let producer = self.producer.read().await; + producer.send_with_confirm(message).await + } +} + +async fn ensure_stream_exists(environment: &Environment, stream: &str) -> RabbitMQStreamResult<()> { + let create_response = environment + .stream_creator() + .max_length(ByteCapacity::GB(5)) + .create(stream) + .await; + + if let Err(e) = create_response { + if let StreamCreateError::Create { stream, status } = e { + match status { + // we can ignore this error because the stream already exists + ResponseCode::StreamAlreadyExists => {} + err => { + panic!("Error creating stream: {:?} {:?}", stream, err); + } + } + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let environment = Environment::builder().build().await?; + let stream = "hello-rust-stream"; + + let producer = MyHAProducer::new(environment, stream).await?; + + producer.send_with_confirm(Message::builder().body("Hello, world!").build()).await?; + + /* + let number_of_messages = 1000000; + for i in 0..number_of_messages { + let msg = Message::builder() + .body(format!("stream message_{}", i)) + .build(); + producer.send_with_confirm(msg).await?; + } + producer.close().await?; + */ + + Ok(()) +} diff --git a/src/client/mod.rs b/src/client/mod.rs index f7faba58..2620f3e2 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -522,7 +522,9 @@ impl Client { .await?; // Start heartbeat task after connection is established - self.start_hearbeat_task(self.state.write().await.deref_mut()); + let mut state = self.state.write().await; + state.heartbeat_task = self.start_hearbeat_task(state.heartbeat); + drop(state); Ok(()) } @@ -661,7 +663,8 @@ impl Client { ); if state.heartbeat_task.take().is_some() { - self.start_hearbeat_task(&mut state); + // Start heartbeat task after connection is established + state.heartbeat_task = self.start_hearbeat_task(state.heartbeat); } drop(state); @@ -674,13 +677,14 @@ impl Client { self.tune_notifier.notify_one(); } - fn start_hearbeat_task(&self, state: &mut ClientState) { - if state.heartbeat == 0 { - return; + fn start_hearbeat_task(&self, heartbeat: u32) -> Option { + if heartbeat == 0 { + return None; } - let heartbeat_interval = (state.heartbeat / 2).max(1); + let heartbeat_interval = (heartbeat / 2).max(1); let channel = self.channel.clone(); - let heartbeat_task = tokio::spawn(async move { + + let heartbeat_task: task::TaskHandle = tokio::spawn(async move { loop { trace!("Sending heartbeat"); if channel @@ -695,7 +699,8 @@ impl Client { warn!("Heartbeat task stopped. Force closing connection"); }) .into(); - state.heartbeat_task = Some(heartbeat_task); + + Some(heartbeat_task) } async fn handle_heart_beat_command(&self) { diff --git a/src/environment.rs b/src/environment.rs index ba8682c8..40e2793e 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -200,6 +200,7 @@ impl Environment { filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), on_closed: None, + overwrite_heartbeat: None, } } diff --git a/src/producer.rs b/src/producer.rs index 5d604f8a..9c98ec80 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -117,6 +117,7 @@ pub struct ProducerBuilder { pub filter_value_extractor: Option, pub(crate) client_provided_name: String, pub(crate) on_closed: Option>, + pub(crate) overwrite_heartbeat: Option, } #[derive(Clone)] @@ -216,6 +217,14 @@ impl ProducerBuilder { self } + pub fn overwrite_heartbeat( + mut self, + heartbeat: u32, + ) -> ProducerBuilder { + self.overwrite_heartbeat = Some(heartbeat); + self + } + pub fn client_provided_name(mut self, name: &str) -> Self { self.client_provided_name = String::from(name); self @@ -231,6 +240,7 @@ impl ProducerBuilder { filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), on_closed: self.on_closed, + overwrite_heartbeat: None, } } diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 0b01a15d..7458a068 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -824,3 +824,44 @@ async fn producer_drop_connection_on_close() { notifier.notified().await; } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_timeout() { + struct Foo { + notifier: Arc, + } + #[async_trait::async_trait] + impl OnClosed for Foo { + async fn on_closed(&self) { + self.notifier.notify_one(); + } + } + + let notifier = Arc::new(Notify::new()); + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .overwrite_heartbeat(1) + .on_closed(Arc::new(Foo { + notifier: notifier.clone(), + })) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + sleep(Duration::from_millis(500)).await; + + let connection = wait_for_named_connection(client_provided_name.clone()).await; + drop_connection(connection).await; + + notifier.notified().await; +} From af13e7cfeb5c4680d97edbfb3aa65d2688aa2208 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Sat, 31 May 2025 11:30:47 +0200 Subject: [PATCH 3/3] Implement heartbeat timeout --- src/client/mod.rs | 73 ++++++++++++++++++++++++++++++++++-------- src/lib.rs | 2 +- src/producer.rs | 9 +++--- tests/producer_test.rs | 11 ++++--- 4 files changed, 71 insertions(+), 24 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 2620f3e2..714f8553 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,4 +1,3 @@ -use std::ops::DerefMut; use std::{ collections::HashMap, io, @@ -142,6 +141,7 @@ pub struct ClientState { max_frame_size: u32, last_heatbeat: Instant, heartbeat_task: Option, + last_received_message: Arc>, } /// Raw API for taking to RabbitMQ stream @@ -165,8 +165,9 @@ impl Client { let (sender, receiver) = Client::create_connection(&broker).await?; - let dispatcher = Dispatcher::new(); + let last_received_message = Arc::new(RwLock::new(Instant::now())); + let dispatcher = Dispatcher::new(); let state = ClientState { server_properties: HashMap::new(), connection_properties: HashMap::new(), @@ -175,6 +176,7 @@ impl Client { max_frame_size: broker.max_frame_size, last_heatbeat: Instant::now(), heartbeat_task: None, + last_received_message: last_received_message.clone(), }; let mut client = Client { dispatcher, @@ -483,6 +485,14 @@ impl Client { self.filtering_supported } + pub async fn set_heartbeat(&self, heartbeat: u32) { + let mut state = self.state.write().await; + state.heartbeat = heartbeat; + // Eventually, this drops the previous heartbeat task + state.heartbeat_task = + self.start_hearbeat_task(heartbeat, state.last_received_message.clone()); + } + async fn create_connection( broker: &ClientOptions, ) -> Result< @@ -500,6 +510,7 @@ impl Client { Ok((tx, rx)) } + async fn initialize(&mut self, receiver: ChannelReceiver) -> Result<(), ClientError> where T: Stream> + Unpin + Send, @@ -523,7 +534,8 @@ impl Client { // Start heartbeat task after connection is established let mut state = self.state.write().await; - state.heartbeat_task = self.start_hearbeat_task(state.heartbeat); + state.heartbeat_task = + self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone()); drop(state); Ok(()) @@ -664,7 +676,8 @@ impl Client { if state.heartbeat_task.take().is_some() { // Start heartbeat task after connection is established - state.heartbeat_task = self.start_hearbeat_task(state.heartbeat); + state.heartbeat_task = + self.start_hearbeat_task(state.heartbeat, state.last_received_message.clone()); } drop(state); @@ -677,14 +690,22 @@ impl Client { self.tune_notifier.notify_one(); } - fn start_hearbeat_task(&self, heartbeat: u32) -> Option { + fn start_hearbeat_task( + &self, + heartbeat: u32, + last_received_message: Arc>, + ) -> Option { if heartbeat == 0 { return None; } let heartbeat_interval = (heartbeat / 2).max(1); let channel = self.channel.clone(); + let client = self.clone(); + let heartbeat_task: task::TaskHandle = tokio::spawn(async move { + let timeout_threashold = u64::from(heartbeat * 4); + loop { trace!("Sending heartbeat"); if channel @@ -695,7 +716,20 @@ impl Client { break; } tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await; + + let now = Instant::now(); + let last_message = last_received_message.read().await; + if now.duration_since(*last_message) >= Duration::from_secs(timeout_threashold) { + warn!("Heartbeat timeout reached. Force closing connection."); + if !client.is_closed() { + if let Err(e) = client.close().await { + warn!("Error closing client: {}", e); + } + } + break; + } } + warn!("Heartbeat task stopped. Force closing connection"); }) .into(); @@ -725,17 +759,28 @@ impl Client { impl MessageHandler for Client { async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> { match &item { - Some(Ok(response)) => match response.kind_ref() { - ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, - ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, - _ => { - if let Some(handler) = self.state.read().await.handler.as_ref() { - let handler = handler.clone(); - - tokio::task::spawn(async move { handler.handle_message(item).await }); + Some(Ok(response)) => { + // Update last received message time: needed for heartbeat task + { + let s = self.state.read().await; + let mut last_received_message = s.last_received_message.write().await; + *last_received_message = Instant::now(); + drop(last_received_message); + drop(s); + } + + match response.kind_ref() { + ResponseKind::Tunes(tune) => self.handle_tune_command(tune).await, + ResponseKind::Heartbeat(_) => self.handle_heart_beat_command().await, + _ => { + if let Some(handler) = self.state.read().await.handler.as_ref() { + let handler = handler.clone(); + + tokio::task::spawn(async move { handler.handle_message(item).await }); + } } } - }, + } Some(Err(err)) => { trace!(?err); if let Some(handler) = self.state.read().await.handler.as_ref() { diff --git a/src/lib.rs b/src/lib.rs index f0fe4017..aaeecd47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ pub use crate::consumer::{ Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration, MessageContext, }; pub use crate::environment::{Environment, EnvironmentBuilder}; -pub use crate::producer::{Dedup, NoDedup, OnClosed, Producer, ProducerBuilder}; +pub use crate::producer::{ConfirmationStatus, Dedup, NoDedup, OnClosed, Producer, ProducerBuilder}; pub mod types { pub use crate::byte_capacity::ByteCapacity; diff --git a/src/producer.rs b/src/producer.rs index 9c98ec80..03544395 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -138,6 +138,10 @@ impl ProducerBuilder { .create_producer_client(stream, self.client_provided_name.clone()) .await?; + if let Some(heartbeat) = self.overwrite_heartbeat { + client.set_heartbeat(heartbeat).await; + } + let mut publish_version = 1; if self.filter_value_extractor.is_some() { @@ -217,10 +221,7 @@ impl ProducerBuilder { self } - pub fn overwrite_heartbeat( - mut self, - heartbeat: u32, - ) -> ProducerBuilder { + pub fn overwrite_heartbeat(mut self, heartbeat: u32) -> ProducerBuilder { self.overwrite_heartbeat = Some(heartbeat); self } diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 7458a068..7313d1ec 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use chrono::Utc; use fake::{Fake, Faker}; use futures::{lock::Mutex, StreamExt}; -use tokio::{sync::mpsc::channel, task::yield_now, time::sleep}; +use tokio::{sync::mpsc::channel, time::sleep}; use rabbitmq_stream_client::{ error::ClientError, @@ -19,7 +19,6 @@ use common::*; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, }; -use tracing::span; use std::sync::atomic::{AtomicU32, Ordering}; use tokio::sync::Notify; @@ -860,8 +859,10 @@ async fn producer_timeout() { sleep(Duration::from_millis(500)).await; - let connection = wait_for_named_connection(client_provided_name.clone()).await; - drop_connection(connection).await; + let is_stopped = tokio::select! { + _ = notifier.notified() => true, + _ = sleep(Duration::from_secs(5)) => false, + }; - notifier.notified().await; + assert!(is_stopped, "Producer did not stop after timeout"); }