Skip to content

RUST-229 Parse IPv6 addresses in the connection string #1242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 98 additions & 71 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{
convert::TryFrom,
fmt::{self, Display, Formatter, Write},
hash::{Hash, Hasher},
net::Ipv6Addr,
path::PathBuf,
str::FromStr,
time::Duration,
Expand Down Expand Up @@ -128,9 +129,29 @@ impl<'de> Deserialize<'de> for ServerAddress {
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
Self::parse(s.as_str())
.map_err(|e| <D::Error as serde::de::Error>::custom(format!("{}", e)))
#[derive(Deserialize)]
#[serde(untagged)]
enum ServerAddressHelper {
String(String),
Object { host: String, port: Option<u16> },
}

let helper = ServerAddressHelper::deserialize(deserializer)?;
match helper {
ServerAddressHelper::String(string) => {
Self::parse(string).map_err(serde::de::Error::custom)
}
ServerAddressHelper::Object { host, port } => {
#[cfg(unix)]
if host.ends_with("sock") {
return Ok(Self::Unix {
path: PathBuf::from(host),
});
}

Ok(Self::Tcp { host, port })
}
}
}
}

Expand Down Expand Up @@ -185,74 +206,95 @@ impl FromStr for ServerAddress {
}

impl ServerAddress {
/// Parses an address string into a `ServerAddress`.
/// Parses an address string into a [`ServerAddress`].
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
let address = address.as_ref();
// checks if the address is a unix domain socket
#[cfg(unix)]
{
if address.ends_with(".sock") {
return Ok(ServerAddress::Unix {

if address.ends_with(".sock") {
#[cfg(unix)]
{
let address = percent_decode(address, "unix domain sockets must be URL-encoded")?;
return Ok(Self::Unix {
path: PathBuf::from(address),
});
}
#[cfg(not(unix))]
return Err(ErrorKind::InvalidArgument {
message: "unix domain sockets are not supported on this platform".to_string(),
}
.into());
}
let mut parts = address.split(':');
let hostname = match parts.next() {
Some(part) => {
if part.is_empty() {
return Err(ErrorKind::InvalidArgument {
message: format!(
"invalid server address: \"{}\"; hostname cannot be empty",
address
),
}
.into());

let (hostname, port) = if let Some(ip_literal) = address.strip_prefix("[") {
let Some((hostname, port)) = ip_literal.split_once("]") else {
return Err(ErrorKind::InvalidArgument {
message: format!(
"invalid server address {}: missing closing ']' in IP literal hostname",
address
),
}
part
}
None => {
.into());
};

if let Err(parse_error) = Ipv6Addr::from_str(hostname) {
return Err(ErrorKind::InvalidArgument {
message: format!("invalid server address: \"{}\"", address),
message: format!("invalid server address {}: {}", address, parse_error),
}
.into())
.into());
}
};

let port = match parts.next() {
Some(part) => {
let port = u16::from_str(part).map_err(|_| ErrorKind::InvalidArgument {
let port = if port.is_empty() {
None
} else if let Some(port) = port.strip_prefix(":") {
Some(port)
} else {
return Err(ErrorKind::InvalidArgument {
message: format!(
"port must be valid 16-bit unsigned integer, instead got: {}",
part
"invalid server address {}: the hostname can only be followed by a port \
prefixed with ':', got {}",
address, port
),
})?;

if port == 0 {
return Err(ErrorKind::InvalidArgument {
message: format!(
"invalid server address: \"{}\"; port must be non-zero",
address
),
}
.into());
}
if parts.next().is_some() {
.into());
};

(hostname, port)
} else {
match address.split_once(":") {
Some((hostname, port)) => (hostname, Some(port)),
None => (address, None),
}
};

if hostname.is_empty() {
return Err(ErrorKind::InvalidArgument {
message: format!(
"invalid server address {}: the hostname cannot be empty",
address
),
}
.into());
}

let port = if let Some(port) = port {
match u16::from_str(port) {
Ok(0) | Err(_) => {
return Err(ErrorKind::InvalidArgument {
message: format!(
"address \"{}\" contains more than one unescaped ':'",
address
"invalid server address {}: the port must be an integer between 1 and \
65535, got {}",
address, port
),
}
.into());
.into())
}

Some(port)
Ok(port) => Some(port),
}
None => None,
} else {
None
};

Ok(ServerAddress::Tcp {
Ok(Self::Tcp {
host: hostname.to_lowercase(),
port,
})
Expand Down Expand Up @@ -1165,6 +1207,7 @@ impl ClientOptions {
.iter()
.filter_map(|addr| match addr {
ServerAddress::Tcp { host, .. } => Some(host.to_ascii_lowercase()),
#[cfg(unix)]
_ => None,
})
.collect()
Expand Down Expand Up @@ -1440,31 +1483,15 @@ impl ConnectionString {
None => (None, None),
};

let mut host_list = Vec::with_capacity(hosts_section.len());
for host in hosts_section.split(',') {
let address = if host.ends_with(".sock") {
#[cfg(unix)]
{
ServerAddress::parse(percent_decode(
host,
"Unix domain sockets must be URL-encoded",
)?)
}
#[cfg(not(unix))]
return Err(ErrorKind::InvalidArgument {
message: "Unix domain sockets are not supported on this platform".to_string(),
}
.into());
} else {
ServerAddress::parse(host)
}?;
host_list.push(address);
}
let hosts = hosts_section
.split(',')
.map(ServerAddress::parse)
.collect::<Result<Vec<ServerAddress>>>()?;

let host_info = if !srv {
HostInfo::HostIdentifiers(host_list)
HostInfo::HostIdentifiers(hosts)
} else {
match &host_list[..] {
match &hosts[..] {
[ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()),
[ServerAddress::Tcp {
host: _,
Expand Down
93 changes: 45 additions & 48 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
bson::{Bson, Document},
bson_util::get_int,
client::options::{ClientOptions, ConnectionString, ServerAddress},
error::{Error, ErrorKind, Result},
error::ErrorKind,
test::spec::deserialize_spec_tests,
Client,
};
Expand All @@ -22,13 +22,6 @@ static SKIPPED_TESTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
"maxPoolSize=0 does not error",
// TODO RUST-226: unskip this test
"Valid tlsCertificateKeyFilePassword is parsed correctly",
// TODO RUST-229: unskip the following tests
"Single IP literal host without port",
"Single IP literal host with port",
"Multiple hosts (mixed formats)",
"User info for single IP literal host without database",
"User info for single IP literal host with database",
"User info for multiple hosts with database",
];

// TODO RUST-1896: unskip this test when openssl-tls is enabled
Expand Down Expand Up @@ -65,43 +58,11 @@ struct TestCase {
uri: String,
valid: bool,
warning: Option<bool>,
hosts: Option<Vec<TestServerAddress>>,
hosts: Option<Vec<ServerAddress>>,
auth: Option<TestAuth>,
options: Option<Document>,
}

// The connection string tests' representation of a server address. We use this indirection to avoid
// deserialization failures when the tests specify an IPv6 address.
//
// TODO RUST-229: remove this struct and deserialize directly into ServerAddress
#[derive(Debug, Deserialize)]
struct TestServerAddress {
#[serde(rename = "type")]
host_type: String,
host: String,
port: Option<u16>,
}

impl TryFrom<&TestServerAddress> for ServerAddress {
type Error = Error;

fn try_from(test_server_address: &TestServerAddress) -> Result<Self> {
if test_server_address.host_type.as_str() == "ip_literal" {
return Err(ErrorKind::Internal {
message: "test using ip_literal host type should be skipped".to_string(),
}
.into());
}

let mut address = Self::parse(&test_server_address.host)?;
if let ServerAddress::Tcp { ref mut port, .. } = address {
*port = test_server_address.port;
}

Ok(address)
}
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
struct TestAuth {
Expand Down Expand Up @@ -138,14 +99,8 @@ async fn run_tests(path: &[&str], skipped_files: &[&str]) {
let client_options = client_options_result.expect(&test_case.description);

if let Some(ref expected_hosts) = test_case.hosts {
let expected_hosts = expected_hosts
.iter()
.map(TryFrom::try_from)
.collect::<Result<Vec<ServerAddress>>>()
.expect(&test_case.description);

assert_eq!(
client_options.hosts, expected_hosts,
&client_options.hosts, expected_hosts,
"{}",
test_case.description
);
Expand Down Expand Up @@ -364,3 +319,45 @@ async fn options_enforce_min_heartbeat_frequency() {

Client::with_options(options).unwrap_err();
}

#[test]
fn invalid_ipv6() {
// invalid hostname for ipv6
let address = "[localhost]:27017";
let error = ServerAddress::parse(address).unwrap_err();
let message = error.message().unwrap();
assert!(message.contains("invalid IPv6 address syntax"), "{message}");

// invalid character after hostname
let address = "[::1]a";
let error = ServerAddress::parse(address).unwrap_err();
let message = error.message().unwrap();
assert!(
message.contains("the hostname can only be followed by a port"),
"{message}"
);

// missing bracket
let address = "[::1:27017";
let error = ServerAddress::parse(address).unwrap_err();
let message = error.message().unwrap();
assert!(message.contains("missing closing ']'"), "{message}");

// extraneous bracket
let address = "[::1]:27017]";
let error = ServerAddress::parse(address).unwrap_err();
let message = error.message().unwrap();
assert!(message.contains("the port must be an integer"), "{message}");
}

#[cfg(not(unix))]
#[test]
fn unix_domain_socket_not_allowed() {
let address = "address.sock";
let error = ServerAddress::parse(address).unwrap_err();
let message = error.message().unwrap();
assert!(
message.contains("not supported on this platform"),
"{message}"
);
}