Skip to content

Commit bb8df38

Browse files
authored
RUST-933 Add support for the srvMaxHosts option (#977)
1 parent 4f7402d commit bb8df38

File tree

15 files changed

+395
-34
lines changed

15 files changed

+395
-34
lines changed

.config/nextest.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[profile.default]
2-
retries = 1
32
test-threads = 1
43

54
[profile.ci]

src/client/options.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,10 @@ pub struct ClientOptions {
586586
#[builder(default)]
587587
pub write_concern: Option<WriteConcern>,
588588

589+
/// Limit on the number of mongos connections that may be created for sharded topologies.
590+
#[builder(default)]
591+
pub srv_max_hosts: Option<u32>,
592+
589593
/// Information from the SRV URI that generated these client options, if applicable.
590594
#[builder(default, setter(skip))]
591595
#[serde(skip)]
@@ -708,6 +712,8 @@ impl Serialize for ClientOptions {
708712
zlibcompressionlevel: &'a Option<i32>,
709713

710714
loadbalanced: &'a Option<bool>,
715+
716+
srvmaxhosts: Option<i32>,
711717
}
712718

713719
let client_options = ClientOptionsHelper {
@@ -732,6 +738,7 @@ impl Serialize for ClientOptions {
732738
writeconcern: &self.write_concern,
733739
loadbalanced: &self.load_balanced,
734740
zlibcompressionlevel: &None,
741+
srvmaxhosts: self.srv_max_hosts.map(|v| v as i32),
735742
};
736743

737744
client_options.serialize(serializer)
@@ -875,6 +882,9 @@ pub struct ConnectionString {
875882
/// [`Binary::to_uuid_with_representation`](bson::binary::Binary::to_uuid_with_representation).
876883
pub uuid_representation: Option<UuidRepresentation>,
877884

885+
/// Limit on the number of mongos connections that may be created for sharded topologies.
886+
pub srv_max_hosts: Option<u32>,
887+
878888
wait_queue_timeout: Option<Duration>,
879889
tls_insecure: Option<bool>,
880890

@@ -1252,6 +1262,24 @@ impl ClientOptions {
12521262
options.load_balanced = config.load_balanced;
12531263
}
12541264

1265+
if let Some(max) = options.srv_max_hosts {
1266+
if max > 0 {
1267+
if options.repl_set_name.is_some() {
1268+
return Err(Error::invalid_argument(
1269+
"srvMaxHosts and replicaSet cannot both be present",
1270+
));
1271+
}
1272+
if options.load_balanced == Some(true) {
1273+
return Err(Error::invalid_argument(
1274+
"srvMaxHosts and loadBalanced=true cannot both be present",
1275+
));
1276+
}
1277+
config.hosts = crate::sdam::choose_n(&config.hosts, max as usize)
1278+
.cloned()
1279+
.collect();
1280+
}
1281+
}
1282+
12551283
// Set the ClientOptions hosts to those found during the SRV lookup.
12561284
config.hosts
12571285
}
@@ -1338,6 +1366,7 @@ impl ClientOptions {
13381366
test_options: None,
13391367
#[cfg(feature = "tracing-unstable")]
13401368
tracing_max_document_length_bytes: None,
1369+
srv_max_hosts: conn_str.srv_max_hosts,
13411370
}
13421371
}
13431372

@@ -1721,6 +1750,26 @@ impl ConnectionString {
17211750
ConnectionStringParts::default()
17221751
};
17231752

1753+
if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
1754+
if !srv {
1755+
return Err(Error::invalid_argument(
1756+
"srvMaxHosts cannot be specified with a non-SRV URI",
1757+
));
1758+
}
1759+
if srv_max_hosts > 0 {
1760+
if conn_str.replica_set.is_some() {
1761+
return Err(Error::invalid_argument(
1762+
"srvMaxHosts and replicaSet cannot both be present",
1763+
));
1764+
}
1765+
if conn_str.load_balanced == Some(true) {
1766+
return Err(Error::invalid_argument(
1767+
"srvMaxHosts and loadBalanced=true cannot both be present",
1768+
));
1769+
}
1770+
}
1771+
}
1772+
17241773
// Set username and password.
17251774
if let Some(u) = username {
17261775
let credential = conn_str.credential.get_or_insert_with(Default::default);
@@ -2147,6 +2196,9 @@ impl ConnectionString {
21472196
k @ "sockettimeoutms" => {
21482197
self.socket_timeout = Some(Duration::from_millis(get_duration!(value, k)));
21492198
}
2199+
k @ "srvmaxhosts" => {
2200+
self.srv_max_hosts = Some(get_u32!(value, k));
2201+
}
21502202
k @ "tls" | k @ "ssl" => {
21512203
let tls = get_bool!(value, k);
21522204

src/client/options/test.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ async fn run_test(test_file: TestFile) {
7575
)
7676
// The Rust driver disallows `maxPoolSize=0`.
7777
|| test_case.description.contains("maxPoolSize=0 does not error")
78+
// TODO RUST-933 implement custom srvServiceName support
79+
|| test_case.description.contains("custom srvServiceName")
7880
{
7981
continue;
8082
}

src/sdam.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub(crate) use self::{
1313
description::{
1414
server::{ServerDescription, TopologyVersion},
1515
topology::{
16+
choose_n,
1617
server_selection::{self, SelectedServer},
1718
verify_max_staleness,
1819
TopologyDescription,

src/sdam/description/topology.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ pub(crate) struct TopologyDescription {
113113

114114
/// The server descriptions of each member of the topology.
115115
pub(crate) servers: HashMap<ServerAddress, ServerDescription>,
116+
117+
/// The maximum number of hosts.
118+
pub(crate) srv_max_hosts: Option<u32>,
116119
}
117120

118121
impl PartialEq for TopologyDescription {
@@ -141,6 +144,7 @@ impl Default for TopologyDescription {
141144
local_threshold: Default::default(),
142145
heartbeat_freq: Default::default(),
143146
servers: Default::default(),
147+
srv_max_hosts: Default::default(),
144148
}
145149
}
146150
}
@@ -177,6 +181,7 @@ impl TopologyDescription {
177181
self.set_name = options.repl_set_name.clone();
178182
self.local_threshold = options.local_threshold;
179183
self.heartbeat_freq = options.heartbeat_freq;
184+
self.srv_max_hosts = options.srv_max_hosts;
180185
}
181186

182187
/// Gets the topology type of the cluster.
@@ -381,9 +386,23 @@ impl TopologyDescription {
381386
/// Syncs the set of servers in the description to those in `hosts`. Servers in the set not
382387
/// already present in the cluster will be added, and servers in the cluster not present in the
383388
/// set will be removed.
384-
pub(crate) fn sync_hosts(&mut self, hosts: &HashSet<ServerAddress>) {
385-
self.add_new_servers_from_addresses(hosts.iter());
389+
pub(crate) fn sync_hosts(&mut self, hosts: HashSet<ServerAddress>) {
386390
self.servers.retain(|host, _| hosts.contains(host));
391+
let mut new = vec![];
392+
for host in hosts {
393+
if !self.servers.contains_key(&host) {
394+
new.push((host.clone(), ServerDescription::new(host)));
395+
}
396+
}
397+
if let Some(max) = self.srv_max_hosts {
398+
let max = max as usize;
399+
if max > 0 && max < self.servers.len() + new.len() {
400+
new = choose_n(&new, max.saturating_sub(self.servers.len()))
401+
.cloned()
402+
.collect();
403+
}
404+
}
405+
self.servers.extend(new);
387406
}
388407

389408
pub(crate) fn transaction_support_status(&self) -> TransactionSupportStatus {
@@ -730,6 +749,11 @@ impl TopologyDescription {
730749
}
731750
}
732751

752+
pub(crate) fn choose_n<T>(values: &[T], n: usize) -> impl Iterator<Item = &T> {
753+
use rand::{prelude::SliceRandom, SeedableRng};
754+
values.choose_multiple(&mut rand::rngs::SmallRng::from_entropy(), n)
755+
}
756+
733757
/// Enum representing whether transactions are supported by the topology.
734758
#[derive(Debug, Clone, Copy, PartialEq)]
735759
pub(crate) enum TransactionSupportStatus {

src/sdam/description/topology/server_selection.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ mod test;
33

44
use std::{collections::HashMap, fmt, ops::Deref, sync::Arc, time::Duration};
55

6-
use rand::{rngs::SmallRng, seq::SliceRandom, SeedableRng};
7-
86
use super::TopologyDescription;
97
use crate::{
108
error::{ErrorKind, Result},
@@ -86,9 +84,7 @@ fn select_server_in_latency_window(in_window: Vec<&Arc<Server>>) -> Option<Arc<S
8684
return Some(in_window[0].clone());
8785
}
8886

89-
let mut rng = SmallRng::from_entropy();
90-
in_window
91-
.choose_multiple(&mut rng, 2)
87+
super::choose_n(&in_window, 2)
9288
.min_by_key(|s| s.operation_count())
9389
.map(|server| (*server).clone())
9490
}

src/sdam/description/topology/server_selection/test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ impl TestTopologyDescription {
5252
local_threshold: None,
5353
heartbeat_freq: heartbeat_frequency,
5454
servers,
55+
srv_max_hosts: None,
5556
}
5657
}
5758
}

src/sdam/srv_polling.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ impl SrvPollingMonitor {
104104

105105
// TODO: RUST-230 Log error with host that was returned.
106106
self.topology_updater
107-
.sync_hosts(lookup.hosts.into_iter().filter_map(Result::ok).collect())
107+
.sync_hosts(lookup.hosts.into_iter().collect())
108108
.await;
109109
}
110110

@@ -120,7 +120,9 @@ impl SrvPollingMonitor {
120120
}
121121
let initial_hostname = self.initial_hostname.clone();
122122
let resolver = self.get_or_create_srv_resolver().await?;
123-
resolver.get_srv_hosts(initial_hostname.as_str()).await
123+
resolver
124+
.get_srv_hosts(initial_hostname.as_str(), crate::srv::DomainMismatch::Skip)
125+
.await
124126
}
125127

126128
async fn get_or_create_srv_resolver(&mut self) -> Result<&SrvResolver> {

src/sdam/srv_polling/test.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,26 @@ lazy_static::lazy_static! {
2626
}
2727

2828
async fn run_test(new_hosts: Result<Vec<ServerAddress>>, expected_hosts: HashSet<ServerAddress>) {
29+
run_test_srv(None, new_hosts, expected_hosts).await
30+
}
31+
32+
async fn run_test_srv(
33+
max_hosts: Option<u32>,
34+
new_hosts: Result<Vec<ServerAddress>>,
35+
expected_hosts: HashSet<ServerAddress>,
36+
) {
37+
let actual = run_test_extra(max_hosts, new_hosts).await;
38+
assert_eq!(expected_hosts, actual);
39+
}
40+
41+
async fn run_test_extra(
42+
max_hosts: Option<u32>,
43+
new_hosts: Result<Vec<ServerAddress>>,
44+
) -> HashSet<ServerAddress> {
2945
let mut options = ClientOptions::new_srv();
3046
options.hosts = DEFAULT_HOSTS.clone();
3147
options.test_options_mut().disable_monitoring_threads = true;
48+
options.srv_max_hosts = max_hosts;
3249
let mut topology = Topology::new(options.clone()).unwrap();
3350
topology.watch().wait_until_initialized().await;
3451
let mut monitor =
@@ -38,12 +55,12 @@ async fn run_test(new_hosts: Result<Vec<ServerAddress>>, expected_hosts: HashSet
3855
.update_hosts(new_hosts.and_then(make_lookup_hosts))
3956
.await;
4057

41-
assert_eq!(expected_hosts, topology.server_addresses());
58+
topology.server_addresses()
4259
}
4360

4461
fn make_lookup_hosts(hosts: Vec<ServerAddress>) -> Result<LookupHosts> {
4562
Ok(LookupHosts {
46-
hosts: hosts.into_iter().map(Result::Ok).collect(),
63+
hosts,
4764
min_ttl: Duration::from_secs(60),
4865
})
4966
}
@@ -136,3 +153,46 @@ async fn load_balanced_no_srv_polling() {
136153
topology.server_addresses()
137154
);
138155
}
156+
157+
// SRV polling with srvMaxHosts MongoClient option: All DNS records are selected (srvMaxHosts = 0)
158+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
159+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
160+
async fn srv_max_hosts_zero() {
161+
let hosts = vec![
162+
localhost_test_build_10gen(27017),
163+
localhost_test_build_10gen(27019),
164+
localhost_test_build_10gen(27020),
165+
];
166+
167+
run_test_srv(None, Ok(hosts.clone()), hosts.clone().into_iter().collect()).await;
168+
run_test_srv(Some(0), Ok(hosts.clone()), hosts.into_iter().collect()).await;
169+
}
170+
171+
// SRV polling with srvMaxHosts MongoClient option: All DNS records are selected (srvMaxHosts >=
172+
// records)
173+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
174+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
175+
async fn srv_max_hosts_gt_actual() {
176+
let hosts = vec![
177+
localhost_test_build_10gen(27019),
178+
localhost_test_build_10gen(27020),
179+
];
180+
181+
run_test_srv(Some(2), Ok(hosts.clone()), hosts.into_iter().collect()).await;
182+
}
183+
184+
// SRV polling with srvMaxHosts MongoClient option: New DNS records are randomly selected
185+
// (srvMaxHosts > 0)
186+
#[cfg_attr(feature = "tokio-runtime", tokio::test)]
187+
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
188+
async fn srv_max_hosts_random() {
189+
let hosts = vec![
190+
localhost_test_build_10gen(27017),
191+
localhost_test_build_10gen(27019),
192+
localhost_test_build_10gen(27020),
193+
];
194+
195+
let actual = run_test_extra(Some(2), Ok(hosts)).await;
196+
assert_eq!(2, actual.len());
197+
assert!(actual.contains(&localhost_test_build_10gen(27017)));
198+
}

src/sdam/topology.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ impl TopologyWorker {
480480

481481
async fn sync_hosts(&mut self, hosts: HashSet<ServerAddress>) -> bool {
482482
let mut new_description = self.topology_description.clone();
483-
new_description.sync_hosts(&hosts);
483+
new_description.sync_hosts(hosts);
484484
self.update_topology(new_description).await
485485
}
486486

0 commit comments

Comments
 (0)