
RUST-1443 Stop executing server monitors after the server has been closed (2.3 backport) #733


Merged · 9 commits · Sep 12, 2022
23 changes: 15 additions & 8 deletions src/runtime/worker_handle.rs
@@ -1,10 +1,10 @@
-use tokio::sync::mpsc;
+use tokio::sync::watch;
 
 /// Handle to a worker. Once all handles have been dropped, the worker
 /// will stop waiting for new requests.
 #[derive(Debug, Clone)]
 pub(crate) struct WorkerHandle {
-    _sender: mpsc::Sender<()>,
+    _receiver: watch::Receiver<()>,
 }
 
 impl WorkerHandle {
@@ -18,24 +18,31 @@ impl WorkerHandle {
 /// Listener used to determine when all handles have been dropped.
 #[derive(Debug)]
 pub(crate) struct WorkerHandleListener {
-    receiver: mpsc::Receiver<()>,
+    sender: watch::Sender<()>,
 }
 
 impl WorkerHandleListener {
     /// Listen until all handles are dropped.
     /// This will not return until all handles are dropped, so make sure to only poll this via
     /// select or with a timeout.
-    pub(crate) async fn wait_for_all_handle_drops(&mut self) {
-        self.receiver.recv().await;
+    pub(crate) async fn wait_for_all_handle_drops(&self) {
+        self.sender.closed().await
     }
 
+    /// Returns whether there are handles still alive.
+    pub(crate) fn is_alive(&self) -> bool {
+        !self.sender.is_closed()
+    }
+
     /// Constructs a new channel for monitoring whether this worker still has references
     /// to it.
     pub(crate) fn channel() -> (WorkerHandle, WorkerHandleListener) {
-        let (sender, receiver) = mpsc::channel(1);
+        let (sender, receiver) = watch::channel(());
         (
-            WorkerHandle { _sender: sender },
-            WorkerHandleListener { receiver },
+            WorkerHandle {
+                _receiver: receiver,
+            },
+            WorkerHandleListener { sender },
         )
     }
 }
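
The switch from `mpsc` to `watch` leans on the fact that a `tokio::sync::watch::Sender` can detect when every `Receiver` has been dropped. Here is a minimal, self-contained sketch of that pattern (not driver code; the spawned tasks are invented for illustration):

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // The listener keeps the Sender; each handle is a cloned Receiver.
    let (sender, receiver) = watch::channel(());

    // Spawn two tasks that each hold a handle for a short time.
    for i in 0..2u32 {
        let handle = receiver.clone();
        tokio::spawn(async move {
            // ... work that keeps the worker alive would happen here ...
            drop(handle);
            println!("task {i} dropped its handle");
        });
    }
    // Drop the original Receiver so only the spawned tasks hold handles.
    drop(receiver);

    // `closed()` resolves once every Receiver is gone; `is_closed()` is the
    // non-blocking check that `is_alive()` above relies on.
    sender.closed().await;
    assert!(sender.is_closed());
}
```

One detail worth noting: `watch::Sender::closed` takes `&self`, while `mpsc::Receiver::recv` takes `&mut self`, which is why `wait_for_all_handle_drops` can now be called on a shared reference.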
21 changes: 14 additions & 7 deletions src/sdam/description/topology/mod.rs
@@ -126,7 +126,7 @@ impl PartialEq for TopologyDescription
 }
 
 impl TopologyDescription {
-    pub(crate) fn new(options: ClientOptions) -> crate::error::Result<Self> {
+    pub(crate) fn initialized(options: &ClientOptions) -> crate::error::Result<Self> {
         verify_max_staleness(
             options
                 .selection_criteria
@@ -146,8 +146,13 @@
 
         let servers: HashMap<_, _> = options
             .hosts
-            .into_iter()
-            .map(|address| (address.clone(), ServerDescription::new(address, None)))
+            .iter()
+            .map(|address| {
+                (
+                    address.clone(),
+                    ServerDescription::new(address.clone(), None),
+                )
+            })
             .collect();
 
         let session_support_status = if topology_type == TopologyType::LoadBalanced {
@@ -164,9 +169,9 @@
         };
 
         Ok(Self {
-            single_seed: servers.len() == 1,
+            single_seed: options.hosts.len() == 1,
             topology_type,
-            set_name: options.repl_set_name,
+            set_name: options.repl_set_name.clone(),
             max_set_version: None,
             max_election_id: None,
             compatibility_error: None,
@@ -435,11 +440,13 @@
             _ => None,
         });
 
-        Some(TopologyDescriptionDiff {
+        let diff = TopologyDescriptionDiff {
             removed_addresses: addresses.difference(&other_addresses).cloned().collect(),
             added_addresses: other_addresses.difference(&addresses).cloned().collect(),
             changed_servers: changed_servers.collect(),
-        })
+        };
+
+        Some(diff)
     }
 
     /// Syncs the set of servers in the description to those in `hosts`. Servers in the set not
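
The `new` → `initialized` rename also changes the signature from consuming `ClientOptions` to borrowing it, so the caller retains ownership of the options after building the description. A rough sketch of the borrow-versus-consume trade-off, using hypothetical simplified stand-ins for the real driver types:

```rust
#[derive(Clone)]
struct ClientOptions {
    hosts: Vec<String>,
    repl_set_name: Option<String>,
}

struct TopologyDescription {
    single_seed: bool,
    set_name: Option<String>,
}

impl TopologyDescription {
    // Borrowing `options` instead of consuming it: fields that need ownership
    // are cloned, and the caller can keep using `options` afterwards.
    fn initialized(options: &ClientOptions) -> Self {
        TopologyDescription {
            single_seed: options.hosts.len() == 1,
            set_name: options.repl_set_name.clone(),
        }
    }
}

fn main() {
    let options = ClientOptions {
        hosts: vec!["localhost:27017".into(), "localhost:27018".into()],
        repl_set_name: Some("rs0".into()),
    };
    let description = TopologyDescription::initialized(&options);
    // `options` is still live here; under the old consuming signature this
    // line would not compile.
    assert!(!description.single_seed);
    assert_eq!(description.set_name, options.repl_set_name);
}
```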
28 changes: 20 additions & 8 deletions src/sdam/monitor.rs
@@ -20,7 +20,7 @@ use crate::{
     },
     hello::{hello_command, run_hello, HelloReply},
     options::{ClientOptions, ServerAddress},
-    runtime,
+    runtime::{self, WorkerHandle, WorkerHandleListener},
 };
 
 pub(crate) const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
@@ -36,6 +36,7 @@ pub(crate) struct Monitor {
     sdam_event_emitter: Option<SdamEventEmitter>,
     update_request_receiver: TopologyCheckRequestReceiver,
     client_options: ClientOptions,
+    server_handle_listener: WorkerHandleListener,
 }
 
 impl Monitor {
@@ -46,8 +47,9 @@
         sdam_event_emitter: Option<SdamEventEmitter>,
         update_request_receiver: TopologyCheckRequestReceiver,
         client_options: ClientOptions,
-    ) {
+    ) -> WorkerHandle {
         let handshaker = Handshaker::new(Some(client_options.clone().into()));
+        let (handle, server_handle_listener) = WorkerHandleListener::channel();
         let monitor = Self {
             address,
             client_options,
@@ -57,8 +59,14 @@
             sdam_event_emitter,
             update_request_receiver,
             connection: None,
+            server_handle_listener,
         };
-        runtime::execute(monitor.execute())
+        runtime::execute(monitor.execute());
+        handle
     }
 
+    fn is_alive(&self) -> bool {
+        self.server_handle_listener.is_alive()
+    }
+
     async fn execute(mut self) {
@@ -67,7 +75,7 @@
             .heartbeat_freq
             .unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY);
 
-        while self.topology_watcher.is_alive() {
+        while self.is_alive() {
             self.check_server().await;
 
             #[cfg(test)]
@@ -81,10 +89,14 @@
             #[cfg(not(test))]
             let min_frequency = MIN_HEARTBEAT_FREQUENCY;
 
-            runtime::delay_for(min_frequency).await;
-            self.update_request_receiver
-                .wait_for_check_request(heartbeat_frequency - min_frequency)
-                .await;
+            if self.is_alive() {
+                runtime::delay_for(min_frequency).await;
+            }
+            if self.is_alive() {
+                self.update_request_receiver
+                    .wait_for_check_request(heartbeat_frequency - min_frequency)
+                    .await;
+            }
         }
     }
 
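
The heart of the fix is in `execute`: the loop is now gated on the monitor's own handle listener, and liveness is re-checked before each await point so a monitor whose server was removed exits promptly instead of sleeping through another heartbeat interval. A stripped-down sketch of the shutdown mechanics, substituting a bare `watch` channel for the driver's `WorkerHandle`/`WorkerHandleListener` (all names here are illustrative):

```rust
use std::time::Duration;
use tokio::sync::watch;

struct ToyMonitor {
    // Stands in for `WorkerHandleListener`: alive while any Receiver exists.
    listener: watch::Sender<()>,
}

impl ToyMonitor {
    fn is_alive(&self) -> bool {
        !self.listener.is_closed()
    }

    async fn execute(self) {
        while self.is_alive() {
            println!("checking server...");
            // Re-check before sleeping, mirroring the monitor loop above.
            if self.is_alive() {
                tokio::time::sleep(Duration::from_millis(50)).await;
            }
        }
        println!("all handles dropped; monitor exiting");
    }
}

#[tokio::main]
async fn main() {
    let (listener, handle) = watch::channel(());
    let task = tokio::spawn(ToyMonitor { listener }.execute());

    tokio::time::sleep(Duration::from_millis(120)).await;
    drop(handle); // the server was "closed": its last handle is gone
    task.await.unwrap();
}
```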
93 changes: 93 additions & 0 deletions src/sdam/test.rs
@@ -1,4 +1,5 @@
 use std::{
+    collections::HashSet,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -8,7 +9,10 @@ use semver::VersionReq;
 use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
 
 use crate::{
+    client::options::{ClientOptions, ServerAddress},
+    cmap::RawCommandResponse,
     error::{Error, ErrorKind},
+    event::sdam::SdamEventHandler,
     hello::{LEGACY_HELLO_COMMAND_NAME, LEGACY_HELLO_COMMAND_NAME_LOWERCASE},
     runtime,
     test::{
@@ -28,6 +32,8 @@ use crate::{
     Client,
 };
 
+use super::{ServerDescription, Topology};
+
 #[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
 #[cfg_attr(feature = "async-std-runtime", async_std::test)]
 async fn min_heartbeat_frequency() {
[Review comment from the contributor on `removed_server_monitor_stops`: This test will need to be forward-ported to the main branch once RUST-360 lands.]

@@ -485,3 +491,90 @@ async fn repl_set_name_mismatch() -> crate::error::Result<()> {
 
     Ok(())
 }
+
+/// Test verifying that a server's monitor stops after the server has been removed from the
+/// topology.
+#[cfg_attr(feature = "tokio-runtime", tokio::test(flavor = "multi_thread"))]
+#[cfg_attr(feature = "async-std-runtime", async_std::test)]
+async fn removed_server_monitor_stops() -> crate::error::Result<()> {
+    let _guard = LOCK.run_concurrently().await;
+
+    let handler = Arc::new(EventHandler::new());
+    let options = ClientOptions::builder()
+        .hosts(vec![
+            ServerAddress::parse("localhost:49152")?,
+            ServerAddress::parse("localhost:49153")?,
+            ServerAddress::parse("localhost:49154")?,
+        ])
+        .heartbeat_freq(Duration::from_millis(50))
+        .sdam_event_handler(handler.clone() as Arc<dyn SdamEventHandler>)
+        .repl_set_name("foo".to_string())
+        .build();
+
+    let hosts = options.hosts.clone();
+    let set_name = options.repl_set_name.clone().unwrap();
+
+    let mut subscriber = handler.subscribe();
+    let topology = Topology::new(options)?;
+
+    // Wait until all three monitors have started.
+    let mut seen_monitors = HashSet::new();
+    subscriber
+        .wait_for_event(Duration::from_millis(500), |event| {
+            if let Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) = event {
+                seen_monitors.insert(e.server_address.clone());
+            }
+            seen_monitors.len() == hosts.len()
+        })
+        .await
+        .expect("should see all three monitors start");
+
+    // Remove the third host from the topology.
+    let hello = doc! {
+        "ok": 1,
+        "isWritablePrimary": true,
+        "hosts": [
+            hosts[0].clone().to_string(),
+            hosts[1].clone().to_string(),
+        ],
+        "me": hosts[0].clone().to_string(),
+        "setName": set_name,
+        "maxBsonObjectSize": 1234,
+        "maxWriteBatchSize": 1234,
+        "maxMessageSizeBytes": 1234,
+        "minWireVersion": 0,
+        "maxWireVersion": 13,
+    };
+    let hello_reply = Some(Ok(RawCommandResponse::with_document_and_address(
+        hosts[0].clone(),
+        hello,
+    )
+    .unwrap()
+    .into_hello_reply(Duration::from_millis(10))
+    .unwrap()));
+
+    topology
+        .clone_updater()
+        .update(ServerDescription::new(hosts[0].clone(), hello_reply))
+        .await;
+
+    subscriber.wait_for_event(Duration::from_secs(1), |event| {
+        matches!(event, Event::Sdam(SdamEvent::ServerClosed(e)) if e.address == hosts[2])
+    }).await.expect("should see server closed event");
+
+    // Capture heartbeat events for 1 second. The monitor for the removed server should stop
+    // publishing them.
+    let events = subscriber.collect_events(Duration::from_secs(1), |event| {
+        matches!(event, Event::Sdam(SdamEvent::ServerHeartbeatStarted(e)) if e.server_address == hosts[2])
+    }).await;
+
+    // Use 3 to account for any heartbeats that happen to start between emitting the ServerClosed
+    // event and actually publishing the state with the closed server.
+    assert!(
+        events.len() < 3,
+        "expected monitor for removed server to stop performing checks, but saw {} heartbeats",
+        events.len()
+    );
+
+    Ok(())
+}