Skip to content

RUST-950 Enable connection to a load balancer #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3f1c162
add a non-public load balancer feature flag
abr-egn Jul 28, 2021
de0deab
recognize loadBalanced URL param
abr-egn Jul 28, 2021
797f488
parse TXT option
abr-egn Jul 28, 2021
1fe9d68
validate options
abr-egn Jul 28, 2021
4728d7f
introduce topology and server types
abr-egn Jul 28, 2021
10537a3
mark connection pool ready
abr-egn Jul 29, 2021
737b414
create initial server with LoadBalancer type
abr-egn Jul 29, 2021
83f0d8d
send loadBalanced hello parameter, check for serviceId in reply
abr-egn Jul 29, 2021
bb19694
include service_id in TestIsMasterCommandResponse
abr-egn Jul 29, 2021
54916f3
add generation map
abr-egn Jul 30, 2021
45b7963
determine connection generation from service id, and use such in stal…
abr-egn Aug 2, 2021
44bba33
track generation and connection count
abr-egn Aug 3, 2021
c47e59a
fix tests
abr-egn Aug 3, 2021
31fa505
rustfmt
abr-egn Aug 3, 2021
b6d0b2e
clippy
abr-egn Aug 3, 2021
2ee6e39
remodel generation
abr-egn Aug 5, 2021
f5630db
rustfmt
abr-egn Aug 5, 2021
1c8c130
lint
abr-egn Aug 5, 2021
bd6fe46
distinguish pre-hello failures from post-hello
abr-egn Aug 6, 2021
8d8a54f
fix sdam test
abr-egn Aug 6, 2021
b3c4ad1
rustfmt
abr-egn Aug 6, 2021
518896e
fix TXT option case match
abr-egn Aug 9, 2021
fb6c86d
panic in test cfg on load balanced mismatch
abr-egn Aug 9, 2021
334a463
fix is_before_completion
abr-egn Aug 9, 2021
d21165b
tidy macro
abr-egn Aug 9, 2021
4c38e7a
rustfmt and clippy
abr-egn Aug 9, 2021
514328c
another case fix
abr-egn Aug 9, 2021
26203f4
review updats
abr-egn Aug 10, 2021
78bfb11
reuse is_load_balanced
abr-egn Aug 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion src/client/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,16 @@ pub struct ClientOptions {
#[builder(default)]
#[cfg(test)]
pub(crate) heartbeat_freq_test: Option<Duration>,

/// Allow use of the `load_balanced` option.
// TODO RUST-653 Remove this when load balancer work is ready for release.
#[builder(default, setter(skip))]
#[serde(skip)]
pub(crate) allow_load_balanced: bool,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This awkward double-flag construction allows load_balanced to be populated by the URI/TXT parsing as it will be on release, but prevents that from being used outside of tests that have access to set allow_load_balanced to be true.


/// Whether or not the client is connecting to a MongoDB cluster through a load balancer.
#[builder(default, setter(skip))]
pub(crate) load_balanced: Option<bool>,
}

fn default_hosts() -> Vec<ServerAddress> {
Expand Down Expand Up @@ -689,6 +699,7 @@ struct ClientOptionsParser {
auth_mechanism_properties: Option<Document>,
read_preference: Option<ReadPreference>,
read_preference_tags: Option<Vec<TagSet>>,
load_balanced: Option<bool>,
original_uri: String,
}

Expand Down Expand Up @@ -921,6 +932,8 @@ impl From<ClientOptionsParser> for ClientOptions {
server_api: None,
#[cfg(test)]
heartbeat_freq_test: None,
allow_load_balanced: false,
load_balanced: parser.load_balanced,
}
}
}
Expand Down Expand Up @@ -1086,6 +1099,10 @@ impl ClientOptions {
options.repl_set_name = Some(replica_set);
}
}

if options.load_balanced.is_none() {
options.load_balanced = config.load_balanced;
}
}

options.validate()?;
Expand All @@ -1108,7 +1125,7 @@ impl ClientOptions {
}
}

/// Ensure the options set are valid, returning an error descirbing the problem if they are not.
/// Ensure the options set are valid, returning an error describing the problem if they are not.
pub(crate) fn validate(&self) -> Result<()> {
if let Some(true) = self.direct_connection {
if self.hosts.len() > 1 {
Expand All @@ -1122,6 +1139,36 @@ impl ClientOptions {
if let Some(ref write_concern) = self.write_concern {
write_concern.validate()?;
}

if !self.allow_load_balanced && self.load_balanced.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "loadBalanced is not supported".to_string(),
}
.into());
}

if self.load_balanced.unwrap_or(false) {
if self.hosts.len() > 1 {
return Err(ErrorKind::InvalidArgument {
message: "cannot specify multiple seeds with loadBalanced=true".to_string(),
}
.into());
}
if self.repl_set_name.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "cannot specify replicaSet with loadBalanced=true".to_string(),
}
.into());
}
if self.direct_connection == Some(true) {
return Err(ErrorKind::InvalidArgument {
message: "cannot specify directConnection=true with loadBalanced=true"
.to_string(),
}
.into());
}
}

Ok(())
}

Expand Down Expand Up @@ -1677,6 +1724,9 @@ impl ClientOptionsParser {
let mut write_concern = self.write_concern.get_or_insert_with(Default::default);
write_concern.journal = Some(get_bool!(value, k));
}
k @ "loadbalanced" => {
self.load_balanced = Some(get_bool!(value, k));
}
k @ "localthresholdms" => {
self.local_threshold = Some(Duration::from_millis(get_duration!(value, k)))
}
Expand Down
61 changes: 49 additions & 12 deletions src/cmap/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ use derivative::Derivative;
use self::wire::Message;
use super::manager::PoolManager;
use crate::{
cmap::options::{ConnectionOptions, StreamOptions},
error::{ErrorKind, Result},
bson::oid::ObjectId,
cmap::{
options::{ConnectionOptions, StreamOptions},
PoolGeneration,
},
error::{load_balanced_mode_mismatch, ErrorKind, Result},
event::cmap::{
CmapEventHandler,
ConnectionCheckedInEvent,
Expand Down Expand Up @@ -46,7 +50,7 @@ pub struct ConnectionInfo {
pub(crate) struct Connection {
pub(super) id: u32,
pub(super) address: ServerAddress,
pub(crate) generation: u32,
pub(crate) generation: ConnectionGeneration,

/// The cached StreamDescription from the connection's handshake.
pub(super) stream_description: Option<StreamDescription>,
Expand Down Expand Up @@ -90,7 +94,7 @@ impl Connection {

let conn = Self {
id,
generation,
generation: ConnectionGeneration::Normal(generation),
pool_manager: None,
command_executing: false,
ready_and_available_time: None,
Expand All @@ -106,10 +110,16 @@ impl Connection {

/// Constructs and connects a new connection.
pub(super) async fn connect(pending_connection: PendingConnection) -> Result<Self> {
let generation = match pending_connection.generation {
PoolGeneration::Normal(gen) => gen,
PoolGeneration::LoadBalanced(_) => 0, /* Placeholder; will be overwritten in
* `ConnectionEstablisher::
* establish_connection`. */
};
Self::new(
pending_connection.id,
pending_connection.address.clone(),
pending_connection.generation,
generation,
pending_connection.options,
)
.await
Expand Down Expand Up @@ -181,11 +191,6 @@ impl Connection {
.unwrap_or(false)
}

/// Checks if the connection is stale.
pub(super) fn is_stale(&self, current_generation: u32) -> bool {
self.generation != current_generation
}

/// Checks if the connection is currently executing an operation.
pub(super) fn is_executing(&self) -> bool {
self.command_executing
Expand Down Expand Up @@ -300,7 +305,7 @@ impl Connection {
Connection {
id: self.id,
address: self.address.clone(),
generation: self.generation,
generation: self.generation.clone(),
stream: std::mem::replace(&mut self.stream, AsyncStream::Null),
handler: self.handler.take(),
stream_description: self.stream_description.take(),
Expand Down Expand Up @@ -335,6 +340,38 @@ impl Drop for Connection {
}
}

#[derive(Debug, Clone)]
pub(crate) enum ConnectionGeneration {
Normal(u32),
LoadBalanced {
generation: u32,
service_id: ObjectId,
},
}

impl ConnectionGeneration {
pub(crate) fn service_id(&self) -> Option<ObjectId> {
match self {
ConnectionGeneration::Normal(_) => None,
ConnectionGeneration::LoadBalanced { service_id, .. } => Some(*service_id),
}
}

pub(crate) fn is_stale(&self, current_generation: &PoolGeneration) -> bool {
match (self, current_generation) {
(ConnectionGeneration::Normal(cgen), PoolGeneration::Normal(pgen)) => cgen != pgen,
(
ConnectionGeneration::LoadBalanced {
generation: cgen,
service_id,
},
PoolGeneration::LoadBalanced(gen_map),
) => cgen != gen_map.get(service_id).unwrap_or(&0),
_ => load_balanced_mode_mismatch!(false),
}
}
}

/// Struct encapsulating the information needed to establish a `Connection`.
///
/// Creating a `PendingConnection` contributes towards the total connection count of a pool, despite
Expand All @@ -344,7 +381,7 @@ impl Drop for Connection {
pub(super) struct PendingConnection {
pub(super) id: u32,
pub(super) address: ServerAddress,
pub(super) generation: u32,
pub(super) generation: PoolGeneration,
pub(super) options: Option<ConnectionOptions>,
}

Expand Down
19 changes: 18 additions & 1 deletion src/cmap/establish/handshake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
bson::{doc, Bson, Document},
client::auth::{ClientFirst, FirstRound},
cmap::{options::ConnectionPoolOptions, Command, Connection, StreamDescription},
error::Result,
error::{ErrorKind, Result},
is_master::{is_master_command, run_is_master, IsMasterReply},
options::{AuthMechanism, ClientOptions, Credential, DriverInfo, ServerApi},
};
Expand Down Expand Up @@ -177,6 +177,10 @@ impl Handshaker {
command.target_db = cred.resolved_source().to_string();
credential = Some(cred);
}

if options.load_balanced {
command.body.insert("loadBalanced", true);
}
}

command.body.insert("client", metadata);
Expand All @@ -194,6 +198,16 @@ impl Handshaker {
let client_first = set_speculative_auth_info(&mut command.body, self.credential.as_ref())?;

let mut is_master_reply = run_is_master(command, conn).await?;
if self.command.body.contains_key("loadBalanced")
&& is_master_reply.command_response.service_id.is_none()
{
return Err(ErrorKind::IncompatibleServer {
message: "Driver attempted to initialize in load balancing mode, but the server \
does not support this mode."
.to_string(),
}
.into());
}
conn.stream_description = Some(StreamDescription::from_is_master(is_master_reply.clone()));

// Record the client's message and the server's response from speculative authentication if
Expand Down Expand Up @@ -232,6 +246,7 @@ pub(crate) struct HandshakerOptions {
credential: Option<Credential>,
driver_info: Option<DriverInfo>,
server_api: Option<ServerApi>,
load_balanced: bool,
}

impl From<ConnectionPoolOptions> for HandshakerOptions {
Expand All @@ -241,6 +256,7 @@ impl From<ConnectionPoolOptions> for HandshakerOptions {
credential: options.credential,
driver_info: options.driver_info,
server_api: options.server_api,
load_balanced: options.load_balanced.unwrap_or(false),
}
}
}
Expand All @@ -252,6 +268,7 @@ impl From<ClientOptions> for HandshakerOptions {
credential: options.credential,
driver_info: options.driver_info,
server_api: options.server_api,
load_balanced: options.load_balanced.unwrap_or(false),
}
}
}
Expand Down
71 changes: 62 additions & 9 deletions src/cmap/establish/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ pub(super) mod handshake;
mod test;

use self::handshake::Handshaker;
use super::{conn::PendingConnection, options::ConnectionPoolOptions, Connection};
use super::{
conn::{ConnectionGeneration, PendingConnection},
options::ConnectionPoolOptions,
Connection,
PoolGeneration,
};
use crate::{
client::{auth::Credential, options::ServerApi},
error::Result,
error::{Error as MongoError, ErrorKind},
runtime::HttpClient,
sdam::HandshakePhase,
};

/// Contains the logic to establish a connection, including handshaking, authenticating, and
Expand Down Expand Up @@ -38,26 +44,73 @@ impl ConnectionEstablisher {
pub(super) async fn establish_connection(
&self,
pending_connection: PendingConnection,
) -> Result<Connection> {
let mut connection = Connection::connect(pending_connection).await?;
) -> std::result::Result<Connection, EstablishError> {
let pool_gen = pending_connection.generation.clone();
let mut connection = Connection::connect(pending_connection)
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;

let first_round = self
let handshake = self
.handshaker
.handshake(&mut connection)
.await?
.first_round;
.await
.map_err(|e| EstablishError::pre_hello(e, pool_gen.clone()))?;
let service_id = handshake.is_master_reply.command_response.service_id;

// If the handshake response had a `serviceId` field, this is a connection to a load
// balancer and must derive its generation from the service_generations map.
match (pool_gen, service_id) {
(PoolGeneration::Normal(_), _) => {}
(PoolGeneration::LoadBalanced(gen_map), Some(service_id)) => {
connection.generation = ConnectionGeneration::LoadBalanced {
generation: *gen_map.get(&service_id).unwrap_or(&0),
service_id,
};
}
_ => {
return Err(EstablishError::post_hello(
ErrorKind::Internal {
message: "load-balanced mode mismatch".to_string(),
}
.into(),
connection.generation.clone(),
));
}
}

if let Some(ref credential) = self.credential {
credential
.authenticate_stream(
&mut connection,
&self.http_client,
self.server_api.as_ref(),
first_round,
handshake.first_round,
)
.await?;
.await
.map_err(|e| EstablishError::post_hello(e, connection.generation.clone()))?
}

Ok(connection)
}
}

#[derive(Debug, Clone)]
pub(crate) struct EstablishError {
pub(crate) cause: MongoError,
pub(crate) handshake_phase: HandshakePhase,
}

impl EstablishError {
fn pre_hello(cause: MongoError, generation: PoolGeneration) -> Self {
Self {
cause,
handshake_phase: HandshakePhase::PreHello { generation },
}
}
fn post_hello(cause: MongoError, generation: ConnectionGeneration) -> Self {
Self {
cause,
handshake_phase: HandshakePhase::PostHello { generation },
}
}
}
Loading