Skip to content

RUST-933 Add support for the srvMaxHosts option #977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[profile.default]
retries = 1
test-threads = 1

[profile.ci]
Expand Down
52 changes: 52 additions & 0 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,10 @@ pub struct ClientOptions {
#[builder(default)]
pub write_concern: Option<WriteConcern>,

/// Limit on the number of mongos connections that may be created for sharded topologies.
#[builder(default)]
pub srv_max_hosts: Option<u32>,

/// Information from the SRV URI that generated these client options, if applicable.
#[builder(default, setter(skip))]
#[serde(skip)]
Expand Down Expand Up @@ -708,6 +712,8 @@ impl Serialize for ClientOptions {
zlibcompressionlevel: &'a Option<i32>,

loadbalanced: &'a Option<bool>,

srvmaxhosts: Option<i32>,
}

let client_options = ClientOptionsHelper {
Expand All @@ -732,6 +738,7 @@ impl Serialize for ClientOptions {
writeconcern: &self.write_concern,
loadbalanced: &self.load_balanced,
zlibcompressionlevel: &None,
srvmaxhosts: self.srv_max_hosts.map(|v| v as i32),
};

client_options.serialize(serializer)
Expand Down Expand Up @@ -875,6 +882,9 @@ pub struct ConnectionString {
/// [`Binary::to_uuid_with_representation`](bson::binary::Binary::to_uuid_with_representation).
pub uuid_representation: Option<UuidRepresentation>,

/// Limit on the number of mongos connections that may be created for sharded topologies.
pub srv_max_hosts: Option<u32>,

wait_queue_timeout: Option<Duration>,
tls_insecure: Option<bool>,

Expand Down Expand Up @@ -1252,6 +1262,24 @@ impl ClientOptions {
options.load_balanced = config.load_balanced;
}

if let Some(max) = options.srv_max_hosts {
if max > 0 {
if options.repl_set_name.is_some() {
return Err(Error::invalid_argument(
"srvMaxHosts and replicaSet cannot both be present",
));
}
if options.load_balanced == Some(true) {
return Err(Error::invalid_argument(
"srvMaxHosts and loadBalanced=true cannot both be present",
));
}
config.hosts = crate::sdam::choose_n(&config.hosts, max as usize)
.cloned()
.collect();
}
}

// Set the ClientOptions hosts to those found during the SRV lookup.
config.hosts
}
Expand Down Expand Up @@ -1338,6 +1366,7 @@ impl ClientOptions {
test_options: None,
#[cfg(feature = "tracing-unstable")]
tracing_max_document_length_bytes: None,
srv_max_hosts: conn_str.srv_max_hosts,
}
}

Expand Down Expand Up @@ -1721,6 +1750,26 @@ impl ConnectionString {
ConnectionStringParts::default()
};

if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, the spec says that the validation here has to happen both for the initial URI parse and after SRV resolution, so that gets duplicated too :/

if !srv {
return Err(Error::invalid_argument(
"srvMaxHosts cannot be specified with a non-SRV URI",
));
}
if srv_max_hosts > 0 {
if conn_str.replica_set.is_some() {
return Err(Error::invalid_argument(
"srvMaxHosts and replicaSet cannot both be present",
));
}
if conn_str.load_balanced == Some(true) {
return Err(Error::invalid_argument(
"srvMaxHosts and loadBalanced=true cannot both be present",
));
}
}
}

// Set username and password.
if let Some(u) = username {
let credential = conn_str.credential.get_or_insert_with(Default::default);
Expand Down Expand Up @@ -2147,6 +2196,9 @@ impl ConnectionString {
k @ "sockettimeoutms" => {
self.socket_timeout = Some(Duration::from_millis(get_duration!(value, k)));
}
k @ "srvmaxhosts" => {
self.srv_max_hosts = Some(get_u32!(value, k));
}
k @ "tls" | k @ "ssl" => {
let tls = get_bool!(value, k);

Expand Down
2 changes: 2 additions & 0 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ async fn run_test(test_file: TestFile) {
)
// The Rust driver disallows `maxPoolSize=0`.
|| test_case.description.contains("maxPoolSize=0 does not error")
// TODO RUST-933 implement custom srvServiceName support
|| test_case.description.contains("custom srvServiceName")
{
continue;
}
Expand Down
1 change: 1 addition & 0 deletions src/sdam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(crate) use self::{
description::{
server::{ServerDescription, TopologyVersion},
topology::{
choose_n,
server_selection::{self, SelectedServer},
verify_max_staleness,
TopologyDescription,
Expand Down
28 changes: 26 additions & 2 deletions src/sdam/description/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ pub(crate) struct TopologyDescription {

/// The server descriptions of each member of the topology.
pub(crate) servers: HashMap<ServerAddress, ServerDescription>,

/// The maximum number of hosts.
pub(crate) srv_max_hosts: Option<u32>,
}

impl PartialEq for TopologyDescription {
Expand Down Expand Up @@ -141,6 +144,7 @@ impl Default for TopologyDescription {
local_threshold: Default::default(),
heartbeat_freq: Default::default(),
servers: Default::default(),
srv_max_hosts: Default::default(),
}
}
}
Expand Down Expand Up @@ -177,6 +181,7 @@ impl TopologyDescription {
self.set_name = options.repl_set_name.clone();
self.local_threshold = options.local_threshold;
self.heartbeat_freq = options.heartbeat_freq;
self.srv_max_hosts = options.srv_max_hosts;
}

/// Gets the topology type of the cluster.
Expand Down Expand Up @@ -381,9 +386,23 @@ impl TopologyDescription {
/// Syncs the set of servers in the description to those in `hosts`. Servers in the set not
/// already present in the cluster will be added, and servers in the cluster not present in the
/// set will be removed.
pub(crate) fn sync_hosts(&mut self, hosts: &HashSet<ServerAddress>) {
self.add_new_servers_from_addresses(hosts.iter());
pub(crate) fn sync_hosts(&mut self, hosts: HashSet<ServerAddress>) {
self.servers.retain(|host, _| hosts.contains(host));
let mut new = vec![];
for host in hosts {
if !self.servers.contains_key(&host) {
new.push((host.clone(), ServerDescription::new(host)));
}
}
if let Some(max) = self.srv_max_hosts {
let max = max as usize;
if max > 0 && max < self.servers.len() + new.len() {
new = choose_n(&new, max.saturating_sub(self.servers.len()))
.cloned()
.collect();
}
}
self.servers.extend(new);
}

pub(crate) fn transaction_support_status(&self) -> TransactionSupportStatus {
Expand Down Expand Up @@ -730,6 +749,11 @@ impl TopologyDescription {
}
}

pub(crate) fn choose_n<T>(values: &[T], n: usize) -> impl Iterator<Item = &T> {
use rand::{prelude::SliceRandom, SeedableRng};
values.choose_multiple(&mut rand::rngs::SmallRng::from_entropy(), n)
}

/// Enum representing whether transactions are supported by the topology.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum TransactionSupportStatus {
Expand Down
6 changes: 1 addition & 5 deletions src/sdam/description/topology/server_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ mod test;

use std::{collections::HashMap, fmt, ops::Deref, sync::Arc, time::Duration};

use rand::{rngs::SmallRng, seq::SliceRandom, SeedableRng};

use super::TopologyDescription;
use crate::{
error::{ErrorKind, Result},
Expand Down Expand Up @@ -86,9 +84,7 @@ fn select_server_in_latency_window(in_window: Vec<&Arc<Server>>) -> Option<Arc<S
return Some(in_window[0].clone());
}

let mut rng = SmallRng::from_entropy();
in_window
.choose_multiple(&mut rng, 2)
super::choose_n(&in_window, 2)
.min_by_key(|s| s.operation_count())
.map(|server| (*server).clone())
}
Expand Down
1 change: 1 addition & 0 deletions src/sdam/description/topology/server_selection/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ impl TestTopologyDescription {
local_threshold: None,
heartbeat_freq: heartbeat_frequency,
servers,
srv_max_hosts: None,
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/sdam/srv_polling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl SrvPollingMonitor {

// TODO: RUST-230 Log error with host that was returned.
self.topology_updater
.sync_hosts(lookup.hosts.into_iter().filter_map(Result::ok).collect())
.sync_hosts(lookup.hosts.into_iter().collect())
.await;
}

Expand All @@ -120,7 +120,9 @@ impl SrvPollingMonitor {
}
let initial_hostname = self.initial_hostname.clone();
let resolver = self.get_or_create_srv_resolver().await?;
resolver.get_srv_hosts(initial_hostname.as_str()).await
resolver
.get_srv_hosts(initial_hostname.as_str(), crate::srv::DomainMismatch::Skip)
.await
}

async fn get_or_create_srv_resolver(&mut self) -> Result<&SrvResolver> {
Expand Down
64 changes: 62 additions & 2 deletions src/sdam/srv_polling/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,26 @@ lazy_static::lazy_static! {
}

async fn run_test(new_hosts: Result<Vec<ServerAddress>>, expected_hosts: HashSet<ServerAddress>) {
run_test_srv(None, new_hosts, expected_hosts).await
}

async fn run_test_srv(
max_hosts: Option<u32>,
new_hosts: Result<Vec<ServerAddress>>,
expected_hosts: HashSet<ServerAddress>,
) {
let actual = run_test_extra(max_hosts, new_hosts).await;
assert_eq!(expected_hosts, actual);
}

async fn run_test_extra(
max_hosts: Option<u32>,
new_hosts: Result<Vec<ServerAddress>>,
) -> HashSet<ServerAddress> {
let mut options = ClientOptions::new_srv();
options.hosts = DEFAULT_HOSTS.clone();
options.test_options_mut().disable_monitoring_threads = true;
options.srv_max_hosts = max_hosts;
let mut topology = Topology::new(options.clone()).unwrap();
topology.watch().wait_until_initialized().await;
let mut monitor =
Expand All @@ -38,12 +55,12 @@ async fn run_test(new_hosts: Result<Vec<ServerAddress>>, expected_hosts: HashSet
.update_hosts(new_hosts.and_then(make_lookup_hosts))
.await;

assert_eq!(expected_hosts, topology.server_addresses());
topology.server_addresses()
}

fn make_lookup_hosts(hosts: Vec<ServerAddress>) -> Result<LookupHosts> {
Ok(LookupHosts {
hosts: hosts.into_iter().map(Result::Ok).collect(),
hosts,
min_ttl: Duration::from_secs(60),
})
}
Expand Down Expand Up @@ -136,3 +153,46 @@ async fn load_balanced_no_srv_polling() {
topology.server_addresses()
);
}

// SRV polling with srvMaxHosts MongoClient option: All DNS records are selected (srvMaxHosts = 0)
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn srv_max_hosts_zero() {
let hosts = vec![
localhost_test_build_10gen(27017),
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];

run_test_srv(None, Ok(hosts.clone()), hosts.clone().into_iter().collect()).await;
run_test_srv(Some(0), Ok(hosts.clone()), hosts.into_iter().collect()).await;
}

// SRV polling with srvMaxHosts MongoClient option: All DNS records are selected (srvMaxHosts >=
// records)
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn srv_max_hosts_gt_actual() {
let hosts = vec![
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];

run_test_srv(Some(2), Ok(hosts.clone()), hosts.into_iter().collect()).await;
}

// SRV polling with srvMaxHosts MongoClient option: New DNS records are randomly selected
// (srvMaxHosts > 0)
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn srv_max_hosts_random() {
let hosts = vec![
localhost_test_build_10gen(27017),
localhost_test_build_10gen(27019),
localhost_test_build_10gen(27020),
];

let actual = run_test_extra(Some(2), Ok(hosts)).await;
assert_eq!(2, actual.len());
assert!(actual.contains(&localhost_test_build_10gen(27017)));
}
2 changes: 1 addition & 1 deletion src/sdam/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ impl TopologyWorker {

async fn sync_hosts(&mut self, hosts: HashSet<ServerAddress>) -> bool {
let mut new_description = self.topology_description.clone();
new_description.sync_hosts(&hosts);
new_description.sync_hosts(hosts);
self.update_topology(new_description).await
}

Expand Down
Loading