Skip to content

Implement load balancing #1052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tokio-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ postgres-protocol = { version = "0.6.5", path = "../postgres-protocol" }
postgres-types = { version = "0.2.4", path = "../postgres-types" }
tokio = { version = "1.27", features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec"] }
rand = "0.8.5"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
socket2 = { version = "0.5", features = ["all"] }
Expand Down
2 changes: 1 addition & 1 deletion tokio-postgres/src/cancel_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ where
let has_hostname = config.hostname.is_some();

let socket = connect_socket::connect_socket(
&config.host,
&config.addr,
config.port,
config.connect_timeout,
config.tcp_user_timeout,
Expand Down
16 changes: 13 additions & 3 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use crate::codec::{BackendMessages, FrontendMessage};
#[cfg(feature = "runtime")]
use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::copy_out::CopyOutStream;
Expand All @@ -27,6 +25,10 @@ use postgres_protocol::message::{backend::Message, frontend};
use postgres_types::BorrowToSql;
use std::collections::HashMap;
use std::fmt;
#[cfg(feature = "runtime")]
use std::net::IpAddr;
#[cfg(feature = "runtime")]
use std::path::PathBuf;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(feature = "runtime")]
Expand Down Expand Up @@ -153,14 +155,22 @@ impl InnerClient {
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) struct SocketConfig {
pub host: Host,
pub addr: Addr,
pub hostname: Option<String>,
pub port: u16,
pub connect_timeout: Option<Duration>,
pub tcp_user_timeout: Option<Duration>,
pub keepalive: Option<KeepaliveConfig>,
}

#[cfg(feature = "runtime")]
#[derive(Clone)]
pub(crate) enum Addr {
Tcp(IpAddr),
#[cfg(unix)]
Unix(PathBuf),
}

/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
Expand Down
43 changes: 43 additions & 0 deletions tokio-postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ pub enum ChannelBinding {
Require,
}

/// Load balancing configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
/// Make connection attempts to hosts in the order provided.
Disable,
/// Make connection attempts to hosts in a random order.
Random,
}

/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
Expand Down Expand Up @@ -129,6 +139,12 @@ pub enum Host {
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and
/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter
/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to
/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried
/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults
/// to `disable`.
///
/// ## Examples
///
Expand Down Expand Up @@ -190,6 +206,7 @@ pub struct Config {
pub(crate) keepalive_config: KeepaliveConfig,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) load_balance_hosts: LoadBalanceHosts,
}

impl Default for Config {
Expand Down Expand Up @@ -222,6 +239,7 @@ impl Config {
},
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
load_balance_hosts: LoadBalanceHosts::Disable,
}
}

Expand Down Expand Up @@ -489,6 +507,19 @@ impl Config {
self.channel_binding
}

/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.load_balance_hosts = load_balance_hosts;
self
}

/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.load_balance_hosts
}

fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
Expand Down Expand Up @@ -612,6 +643,18 @@ impl Config {
};
self.channel_binding(channel_binding);
}
"load_balance_hosts" => {
let load_balance_hosts = match value {
"disable" => LoadBalanceHosts::Disable,
"random" => LoadBalanceHosts::Random,
_ => {
return Err(Error::config_parse(Box::new(InvalidValue(
"load_balance_hosts",
))))
}
};
self.load_balance_hosts(load_balance_hosts);
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
Expand Down
93 changes: 71 additions & 22 deletions tokio-postgres/src/connect.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use crate::client::SocketConfig;
use crate::config::{Host, TargetSessionAttrs};
use crate::client::{Addr, SocketConfig};
use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::tls::MakeTlsConnect;
use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use rand::seq::SliceRandom;
use std::task::Poll;
use std::{cmp, io};
use tokio::net;

pub async fn connect<T>(
mut tls: T,
Expand Down Expand Up @@ -40,8 +42,13 @@ where
return Err(Error::config("invalid number of ports".into()));
}

let mut indices = (0..num_hosts).collect::<Vec<_>>();
if config.load_balance_hosts == LoadBalanceHosts::Random {
indices.shuffle(&mut rand::thread_rng());
}

let mut error = None;
for i in 0..num_hosts {
for i in indices {
let host = config.host.get(i);
let hostaddr = config.hostaddr.get(i);
let port = config
Expand All @@ -59,25 +66,15 @@ where
Some(Host::Unix(_)) => None,
None => None,
};
let tls = tls
.make_tls_connect(hostname.as_deref().unwrap_or(""))
.map_err(|e| Error::tls(e.into()))?;

// Try to use the value of hostaddr to establish the TCP connection,
// fallback to host if hostaddr is not present.
let addr = match hostaddr {
Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
None => {
if let Some(host) = host {
host.clone()
} else {
// This is unreachable.
return Err(Error::config("both host and hostaddr are empty".into()));
}
}
None => host.cloned().unwrap(),
};

match connect_once(addr, hostname, port, tls, config).await {
match connect_host(addr, hostname, port, &mut tls, config).await {
Ok((client, connection)) => return Ok((client, connection)),
Err(e) => error = Some(e),
}
Expand All @@ -86,18 +83,66 @@ where
Err(error.unwrap())
}

async fn connect_once<T>(
async fn connect_host<T>(
host: Host,
hostname: Option<String>,
port: u16,
tls: T,
tls: &mut T,
config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: MakeTlsConnect<Socket>,
{
match host {
Host::Tcp(host) => {
let mut addrs = net::lookup_host((&*host, port))
.await
.map_err(Error::connect)?
.collect::<Vec<_>>();

if config.load_balance_hosts == LoadBalanceHosts::Random {
addrs.shuffle(&mut rand::thread_rng());
}

let mut last_err = None;
for addr in addrs {
match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
.await
{
Ok(stream) => return Ok(stream),
Err(e) => {
last_err = Some(e);
continue;
}
};
}

Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
}
#[cfg(unix)]
Host::Unix(path) => {
connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
}
}
}

async fn connect_once<T>(
addr: Addr,
hostname: Option<&str>,
port: u16,
tls: &mut T,
config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: TlsConnect<Socket>,
T: MakeTlsConnect<Socket>,
{
let socket = connect_socket(
&host,
&addr,
port,
config.connect_timeout,
config.tcp_user_timeout,
Expand All @@ -108,6 +153,10 @@ where
},
)
.await?;

let tls = tls
.make_tls_connect(hostname.unwrap_or(""))
.map_err(|e| Error::tls(e.into()))?;
let has_hostname = hostname.is_some();
let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;

Expand Down Expand Up @@ -152,8 +201,8 @@ where
}

client.set_socket_config(SocketConfig {
host,
hostname,
addr,
hostname: hostname.map(|s| s.to_string()),
port,
connect_timeout: config.connect_timeout,
tcp_user_timeout: config.tcp_user_timeout,
Expand Down
65 changes: 22 additions & 43 deletions tokio-postgres/src/connect_socket.rs
Original file line number Diff line number Diff line change
@@ -1,71 +1,50 @@
use crate::config::Host;
use crate::client::Addr;
use crate::keepalive::KeepaliveConfig;
use crate::{Error, Socket};
use socket2::{SockRef, TcpKeepalive};
use std::future::Future;
use std::io;
use std::time::Duration;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{self, TcpStream};
use tokio::time;

pub(crate) async fn connect_socket(
host: &Host,
addr: &Addr,
port: u16,
connect_timeout: Option<Duration>,
#[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option<
Duration,
>,
keepalive_config: Option<&KeepaliveConfig>,
) -> Result<Socket, Error> {
match host {
Host::Tcp(host) => {
let addrs = net::lookup_host((&**host, port))
.await
.map_err(Error::connect)?;
match addr {
Addr::Tcp(ip) => {
let stream =
connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?;

let mut last_err = None;
stream.set_nodelay(true).map_err(Error::connect)?;

for addr in addrs {
let stream =
match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};

stream.set_nodelay(true).map_err(Error::connect)?;

let sock_ref = SockRef::from(&stream);
#[cfg(target_os = "linux")]
{
sock_ref
.set_tcp_user_timeout(tcp_user_timeout)
.map_err(Error::connect)?;
}

if let Some(keepalive_config) = keepalive_config {
sock_ref
.set_tcp_keepalive(&TcpKeepalive::from(keepalive_config))
.map_err(Error::connect)?;
}
let sock_ref = SockRef::from(&stream);
#[cfg(target_os = "linux")]
{
sock_ref
.set_tcp_user_timeout(tcp_user_timeout)
.map_err(Error::connect)?;
}

return Ok(Socket::new_tcp(stream));
if let Some(keepalive_config) = keepalive_config {
sock_ref
.set_tcp_keepalive(&TcpKeepalive::from(keepalive_config))
.map_err(Error::connect)?;
}

Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
Ok(Socket::new_tcp(stream))
}
#[cfg(unix)]
Host::Unix(path) => {
let path = path.join(format!(".s.PGSQL.{}", port));
Addr::Unix(dir) => {
let path = dir.join(format!(".s.PGSQL.{}", port));
let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?;
Ok(Socket::new_unix(socket))
}
Expand Down