,
@@ -95,7 +95,7 @@ pub(crate) struct QueryPlanLogger<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq
#[cfg(feature = "sqlite")]
impl<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq, P: Debug> QueryPlanLogger<'q, O, R, P> {
- pub(crate) fn new(sql: &'q str, settings: LogSettings) -> Self {
+ pub fn new(sql: &'q str, settings: LogSettings) -> Self {
Self {
sql,
unknown_operations: HashSet::new(),
@@ -105,19 +105,19 @@ impl<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq, P: Debug> QueryPlanLogger<'
}
}
- pub(crate) fn add_result(&mut self, result: R) {
+ pub fn add_result(&mut self, result: R) {
self.results.insert(result);
}
- pub(crate) fn add_program(&mut self, program: Vec) {
+ pub fn add_program(&mut self, program: Vec
) {
self.program = program;
}
- pub(crate) fn add_unknown_operation(&mut self, operation: O) {
+ pub fn add_unknown_operation(&mut self, operation: O) {
self.unknown_operations.insert(operation);
}
- pub(crate) fn finish(&self) {
+ pub fn finish(&self) {
let lvl = self.settings.statements_level;
if let Some(lvl) = lvl
diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs
index 5526dda795..e55e8da1ab 100644
--- a/sqlx-core/src/migrate/migrator.rs
+++ b/sqlx-core/src/migrate/migrator.rs
@@ -38,7 +38,7 @@ impl Migrator {
/// ```rust,no_run
/// # use sqlx_core::migrate::MigrateError;
/// # fn main() -> Result<(), MigrateError> {
- /// # sqlx_rt::block_on(async move {
+ /// # sqlx::__rt::test_block_on(async move {
/// # use sqlx_core::migrate::Migrator;
/// use std::path::Path;
///
@@ -79,7 +79,7 @@ impl Migrator {
/// # use sqlx_core::migrate::MigrateError;
/// # #[cfg(feature = "sqlite")]
/// # fn main() -> Result<(), MigrateError> {
- /// # sqlx_rt::block_on(async move {
+ /// # sqlx::__rt::test_block_on(async move {
/// # use sqlx_core::migrate::Migrator;
/// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
/// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
@@ -154,7 +154,7 @@ impl Migrator {
/// # use sqlx_core::migrate::MigrateError;
/// # #[cfg(feature = "sqlite")]
/// # fn main() -> Result<(), MigrateError> {
- /// # sqlx_rt::block_on(async move {
+ /// # sqlx::__rt::test_block_on(async move {
/// # use sqlx_core::migrate::Migrator;
/// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
/// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs
index cd0cdca39d..609f4fdeaa 100644
--- a/sqlx-core/src/migrate/source.rs
+++ b/sqlx-core/src/migrate/source.rs
@@ -1,8 +1,8 @@
use crate::error::BoxDynError;
+use crate::fs;
use crate::migrate::{Migration, MigrationType};
use futures_core::future::BoxFuture;
-use futures_util::TryStreamExt;
-use sqlx_rt::fs;
+
use std::borrow::Cow;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
@@ -20,21 +20,16 @@ pub trait MigrationSource<'s>: Debug {
impl<'s> MigrationSource<'s> for &'s Path {
fn resolve(self) -> BoxFuture<'s, Result, BoxDynError>> {
Box::pin(async move {
- #[allow(unused_mut)]
let mut s = fs::read_dir(self.canonicalize()?).await?;
let mut migrations = Vec::new();
- #[cfg(feature = "_rt-tokio")]
- let mut s = tokio_stream::wrappers::ReadDirStream::new(s);
-
- while let Some(entry) = s.try_next().await? {
- if !entry.metadata().await?.is_file() {
+ while let Some(entry) = s.next().await? {
+ if !entry.metadata.is_file() {
// not a file; ignore
continue;
}
- let file_name = entry.file_name();
- let file_name = file_name.to_string_lossy();
+ let file_name = entry.file_name.to_string_lossy();
let parts = file_name.splitn(2, '_').collect::>();
@@ -52,7 +47,7 @@ impl<'s> MigrationSource<'s> for &'s Path {
.replace('_', " ")
.to_owned();
- let sql = fs::read_to_string(&entry.path()).await?;
+ let sql = fs::read_to_string(&entry.path).await?;
migrations.push(Migration::new(
version,
diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs
index 8585f7cf99..2b6558c98e 100644
--- a/sqlx-core/src/mssql/connection/mod.rs
+++ b/sqlx-core/src/mssql/connection/mod.rs
@@ -37,22 +37,10 @@ impl Connection for MssqlConnection {
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
// NOTE: there does not seem to be a clean shutdown packet to send to MSSQL
- #[cfg(feature = "_rt-async-std")]
- {
- use std::future::ready;
- use std::net::Shutdown;
-
- ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed()
- }
-
- #[cfg(feature = "_rt-tokio")]
- {
- use sqlx_rt::AsyncWriteExt;
-
- // FIXME: This is equivalent to Shutdown::Write, not Shutdown::Both like above
- // https://docs.rs/tokio/1.0.1/tokio/io/trait.AsyncWriteExt.html#method.shutdown
- async move { self.stream.shutdown().await.map_err(Into::into) }.boxed()
- }
+ Box::pin(async move {
+ self.stream.shutdown().await?;
+ Ok(())
+ })
}
fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
@@ -78,6 +66,6 @@ impl Connection for MssqlConnection {
#[doc(hidden)]
fn should_flush(&self) -> bool {
- !self.stream.wbuf.is_empty()
+ !self.stream.write_buffer().is_empty()
}
}
diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs
index 1ce061d508..79888429f8 100644
--- a/sqlx-core/src/mssql/connection/stream.rs
+++ b/sqlx-core/src/mssql/connection/stream.rs
@@ -1,11 +1,11 @@
use std::ops::{Deref, DerefMut};
+use std::sync::Arc;
-use bytes::{Bytes, BytesMut};
-use sqlx_rt::TcpStream;
+use bytes::{BufMut, Bytes, BytesMut};
use crate::error::Error;
use crate::ext::ustr::UStr;
-use crate::io::{BufStream, Encode};
+use crate::io::Encode;
use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::{Done, Status as DoneStatus};
use crate::mssql::protocol::env_change::EnvChange;
@@ -19,12 +19,11 @@ use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError};
-use crate::net::MaybeTlsStream;
+use crate::net::{BufferedSocket, Socket, SocketIntoBox};
use crate::HashMap;
-use std::sync::Arc;
pub(crate) struct MssqlStream {
- inner: BufStream>,
+ inner: BufferedSocket>,
// how many Done (or Error) we are currently waiting for
pub(crate) pending_done_count: usize,
@@ -45,12 +44,10 @@ pub(crate) struct MssqlStream {
impl MssqlStream {
pub(super) async fn connect(options: &MssqlConnectOptions) -> Result {
- let inner = BufStream::new(MaybeTlsStream::Raw(
- TcpStream::connect((&*options.host, options.port)).await?,
- ));
+ let socket = crate::net::connect_tcp(&options.host, options.port, SocketIntoBox).await?;
Ok(Self {
- inner,
+ inner: BufferedSocket::new(socket),
columns: Default::default(),
column_names: Default::default(),
response: None,
@@ -68,6 +65,7 @@ impl MssqlStream {
// write out the packet header, leaving room for setting the packet length later
+ let starting_buf_len = self.inner.write_buffer().get().len();
let mut len_offset = 0;
self.inner.write_with(
@@ -78,15 +76,18 @@ impl MssqlStream {
server_process_id: 0,
packet_id: 1,
},
+ // updated by `PacketHeader::encode()`
&mut len_offset,
);
// write out the payload
self.inner.write(payload);
+ let buf = self.inner.write_buffer_mut().get_mut();
+
// overwrite the packet length now that we know it
- let len = self.inner.wbuf.len();
- self.inner.wbuf[len_offset..(len_offset + 2)].copy_from_slice(&(len as u16).to_be_bytes());
+ let len = buf.len() - starting_buf_len;
+ (&mut buf[len_offset..(len_offset + 2)]).put_u16(len as u16);
}
// receive the next packet from the database
@@ -106,10 +107,13 @@ impl MssqlStream {
let mut payload = BytesMut::new();
loop {
- self.inner
- .read_raw_into(&mut payload, (header.length - 8) as usize)
+ let chunk = self
+ .inner
+ .read_buffered((header.length - 8) as usize)
.await?;
+ payload.unsplit(chunk);
+
if header.status.contains(Status::END_OF_MESSAGE) {
break;
}
@@ -202,7 +206,7 @@ impl MssqlStream {
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
- if !self.wbuf.is_empty() {
+ if !self.write_buffer().is_empty() {
self.flush().await?;
}
@@ -222,7 +226,7 @@ impl MssqlStream {
}
impl Deref for MssqlStream {
- type Target = BufStream>;
+ type Target = BufferedSocket>;
fn deref(&self) -> &Self::Target {
&self.inner
diff --git a/sqlx-core/src/mysql/connection/auth.rs b/sqlx-core/src/mysql/connection/auth.rs
index 237fd55288..bb04684dc3 100644
--- a/sqlx-core/src/mysql/connection/auth.rs
+++ b/sqlx-core/src/mysql/connection/auth.rs
@@ -131,7 +131,7 @@ async fn encrypt_rsa<'s>(
) -> Result, Error> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
- if stream.is_tls() {
+ if stream.is_tls {
// If in a TLS stream, send the password directly in clear text
return Ok(to_asciz(password));
}
diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs
index 5352b1a10c..1e65ca710b 100644
--- a/sqlx-core/src/mysql/connection/establish.rs
+++ b/sqlx-core/src/mysql/connection/establish.rs
@@ -1,26 +1,77 @@
use bytes::buf::Buf;
use bytes::Bytes;
+use futures_core::future::BoxFuture;
use crate::common::StatementCache;
use crate::error::Error;
+use crate::mysql::collation::{CharSet, Collation};
use crate::mysql::connection::{tls, MySqlStream, MAX_PACKET_SIZE};
use crate::mysql::protocol::connect::{
AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse,
};
use crate::mysql::protocol::Capabilities;
-use crate::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
+use crate::mysql::{MySqlConnectOptions, MySqlConnection};
+use crate::net::{Socket, WithSocket};
impl MySqlConnection {
pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result {
- let mut stream: MySqlStream = MySqlStream::connect(options).await?;
+ let do_handshake = DoHandshake::new(options)?;
- // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
+ let handshake = match &options.socket {
+ Some(path) => crate::net::connect_uds(path, do_handshake).await?,
+ None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?,
+ };
+
+ let stream = handshake.await?;
+
+ Ok(Self {
+ stream,
+ transaction_depth: 0,
+ cache_statement: StatementCache::new(options.statement_cache_capacity),
+ log_settings: options.log_settings.clone(),
+ })
+ }
+}
+
+struct DoHandshake<'a> {
+ options: &'a MySqlConnectOptions,
+ charset: CharSet,
+ collation: Collation,
+}
+
+impl<'a> DoHandshake<'a> {
+ fn new(options: &'a MySqlConnectOptions) -> Result {
+ let charset: CharSet = options.charset.parse()?;
+ let collation: Collation = options
+ .collation
+ .as_deref()
+ .map(|collation| collation.parse())
+ .transpose()?
+ .unwrap_or_else(|| charset.default_collation());
+
+ Ok(Self {
+ options,
+ charset,
+ collation,
+ })
+ }
+
+ async fn do_handshake(self, socket: S) -> Result {
+ let DoHandshake {
+ options,
+ charset,
+ collation,
+ } = self;
+
+ let mut stream = MySqlStream::with_socket(charset, collation, options, socket);
+
+ // https://dev.mysql.com/doc/internals/en/connection-phase.html
// https://mariadb.com/kb/en/connection/
let handshake: Handshake = stream.recv_packet().await?.decode()?;
let mut plugin = handshake.auth_plugin;
- let mut nonce = handshake.auth_plugin_data;
+ let nonce = handshake.auth_plugin_data;
// FIXME: server version parse is a bit ugly
// expecting MAJOR.MINOR.PATCH
@@ -54,39 +105,7 @@ impl MySqlConnection {
stream.capabilities &= handshake.server_capabilities;
stream.capabilities |= Capabilities::PROTOCOL_41;
- if matches!(options.ssl_mode, MySqlSslMode::Disabled) {
- // remove the SSL capability if SSL has been explicitly disabled
- stream.capabilities.remove(Capabilities::SSL);
- }
-
- // Upgrade to TLS if we were asked to and the server supports it
-
- #[cfg(feature = "_tls-rustls")]
- {
- // To aid in debugging: https://github.com/rustls/rustls/issues/893
-
- let local_addr = stream.local_addr();
-
- match tls::maybe_upgrade(&mut stream, options).await {
- Ok(()) => (),
- #[cfg(feature = "_tls-rustls")]
- Err(Error::Io(ioe)) => {
- if let Some(&rustls::Error::CorruptMessage) =
- ioe.get_ref().and_then(|e| e.downcast_ref())
- {
- log::trace!("got corrupt message on socket {:?}", local_addr);
- }
-
- return Err(Error::Io(ioe));
- }
- Err(e) => return Err(e),
- }
- }
-
- #[cfg(not(feature = "_tls-rustls"))]
- {
- tls::maybe_upgrade(&mut stream, options).await?
- }
+ let mut stream = tls::maybe_upgrade(stream, self.options).await?;
let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) {
Some(plugin.scramble(&mut stream, password, &nonce).await?)
@@ -118,7 +137,7 @@ impl MySqlConnection {
let switch: AuthSwitchRequest = packet.decode()?;
plugin = Some(switch.plugin);
- nonce = switch.data.chain(Bytes::new());
+ let nonce = switch.data.chain(Bytes::new());
let response = switch
.plugin
@@ -140,7 +159,7 @@ impl MySqlConnection {
break;
}
- // plugin signaled to continue authentication
+ // plugin signaled to continue authentication
} else {
return Err(err_protocol!(
"unexpected packet 0x{:02x} during authentication",
@@ -151,11 +170,14 @@ impl MySqlConnection {
}
}
- Ok(Self {
- stream,
- transaction_depth: 0,
- cache_statement: StatementCache::new(options.statement_cache_capacity),
- log_settings: options.log_settings.clone(),
- })
+ Ok(stream)
+ }
+}
+
+impl<'a> WithSocket for DoHandshake<'a> {
+ type Output = BoxFuture<'a, Result>;
+
+ fn with_socket(self, socket: S) -> Self::Output {
+ Box::pin(self.do_handshake(socket))
}
}
diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs
index 1f87eaa918..f35cac8aae 100644
--- a/sqlx-core/src/mysql/connection/mod.rs
+++ b/sqlx-core/src/mysql/connection/mod.rs
@@ -98,7 +98,7 @@ impl Connection for MySqlConnection {
#[doc(hidden)]
fn should_flush(&self) -> bool {
- !self.stream.wbuf.is_empty()
+ !self.stream.write_buffer().is_empty()
}
fn begin(&mut self) -> BoxFuture<'_, Result, Error>>
diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs
index dd9a1235b8..5b058b8e34 100644
--- a/sqlx-core/src/mysql/connection/stream.rs
+++ b/sqlx-core/src/mysql/connection/stream.rs
@@ -4,22 +4,24 @@ use std::ops::{Deref, DerefMut};
use bytes::{Buf, Bytes};
use crate::error::Error;
-use crate::io::{BufStream, Decode, Encode};
+use crate::io::{Decode, Encode};
use crate::mysql::collation::{CharSet, Collation};
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::mysql::protocol::{Capabilities, Packet};
use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
-use crate::net::{MaybeTlsStream, Socket};
+use crate::net::{BufferedSocket, Socket};
-pub struct MySqlStream {
- stream: BufStream>,
+pub struct MySqlStream> {
+ // Wrapping the socket in `Box` allows us to unsize in-place.
+ pub(crate) socket: BufferedSocket,
pub(crate) server_version: (u16, u16, u16),
pub(super) capabilities: Capabilities,
pub(crate) sequence_id: u8,
pub(crate) waiting: VecDeque,
pub(crate) charset: CharSet,
pub(crate) collation: Collation,
+ pub(crate) is_tls: bool,
}
#[derive(Debug, PartialEq, Eq)]
@@ -31,21 +33,13 @@ pub(crate) enum Waiting {
Row,
}
-impl MySqlStream {
- pub(super) async fn connect(options: &MySqlConnectOptions) -> Result {
- let charset: CharSet = options.charset.parse()?;
- let collation: Collation = options
- .collation
- .as_deref()
- .map(|collation| collation.parse())
- .transpose()?
- .unwrap_or_else(|| charset.default_collation());
-
- let socket = match options.socket {
- Some(ref path) => Socket::connect_uds(path).await?,
- None => Socket::connect_tcp(&options.host, options.port).await?,
- };
-
+impl MySqlStream {
+ pub(crate) fn with_socket(
+ charset: CharSet,
+ collation: Collation,
+ options: &MySqlConnectOptions,
+ socket: S,
+ ) -> Self {
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::DEPRECATE_EOF
@@ -63,20 +57,21 @@ impl MySqlStream {
capabilities |= Capabilities::CONNECT_WITH_DB;
}
- Ok(Self {
+ Self {
waiting: VecDeque::new(),
capabilities,
server_version: (0, 0, 0),
sequence_id: 0,
collation,
charset,
- stream: BufStream::new(MaybeTlsStream::Raw(socket)),
- })
+ socket: BufferedSocket::new(socket),
+ is_tls: false,
+ }
}
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
- if !self.stream.wbuf.is_empty() {
- self.stream.flush().await?;
+ if !self.socket.write_buffer().is_empty() {
+ self.socket.flush().await?;
}
while !self.waiting.is_empty() {
@@ -119,14 +114,15 @@ impl MySqlStream {
{
self.sequence_id = 0;
self.write_packet(payload);
- self.flush().await
+ self.flush().await?;
+ Ok(())
}
pub(crate) fn write_packet<'en, T>(&mut self, payload: T)
where
T: Encode<'en, Capabilities>,
{
- self.stream
+ self.socket
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
}
@@ -136,14 +132,14 @@ impl MySqlStream {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
- let mut header: Bytes = self.stream.read(4).await?;
+ let mut header: Bytes = self.socket.read(4).await?;
let packet_size = header.get_uint_le(3) as usize;
let sequence_id = header.get_u8();
self.sequence_id = sequence_id.wrapping_add(1);
- let payload: Bytes = self.stream.read(packet_size).await?;
+ let payload: Bytes = self.socket.read(packet_size).await?;
// TODO: packet compression
// TODO: packet joining
@@ -195,18 +191,31 @@ impl MySqlStream {
Ok(())
}
+
+ pub fn boxed_socket(self) -> MySqlStream {
+ MySqlStream {
+ socket: self.socket.boxed(),
+ server_version: self.server_version,
+ capabilities: self.capabilities,
+ sequence_id: self.sequence_id,
+ waiting: self.waiting,
+ charset: self.charset,
+ collation: self.collation,
+ is_tls: self.is_tls,
+ }
+ }
}
-impl Deref for MySqlStream {
- type Target = BufStream>;
+impl Deref for MySqlStream {
+ type Target = BufferedSocket;
fn deref(&self) -> &Self::Target {
- &self.stream
+ &self.socket
}
}
-impl DerefMut for MySqlStream {
+impl DerefMut for MySqlStream {
fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.stream
+ &mut self.socket
}
}
diff --git a/sqlx-core/src/mysql/connection/tls.rs b/sqlx-core/src/mysql/connection/tls.rs
index 468b638fa8..d1f13792b7 100644
--- a/sqlx-core/src/mysql/connection/tls.rs
+++ b/sqlx-core/src/mysql/connection/tls.rs
@@ -1,39 +1,72 @@
use crate::error::Error;
-use crate::mysql::connection::MySqlStream;
+use crate::mysql::collation::{CharSet, Collation};
+use crate::mysql::connection::{MySqlStream, Waiting};
use crate::mysql::protocol::connect::SslRequest;
use crate::mysql::protocol::Capabilities;
use crate::mysql::{MySqlConnectOptions, MySqlSslMode};
+use crate::net::tls::TlsConfig;
+use crate::net::{tls, BufferedSocket, Socket, WithSocket};
+use std::collections::VecDeque;
-pub(super) async fn maybe_upgrade(
- stream: &mut MySqlStream,
+struct MapStream {
+ server_version: (u16, u16, u16),
+ capabilities: Capabilities,
+ sequence_id: u8,
+ waiting: VecDeque,
+ charset: CharSet,
+ collation: Collation,
+}
+
+pub(super) async fn maybe_upgrade(
+ mut stream: MySqlStream,
options: &MySqlConnectOptions,
-) -> Result<(), Error> {
+) -> Result {
+ let server_supports_tls = stream.capabilities.contains(Capabilities::SSL);
+
+ if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() {
+ // remove the SSL capability if SSL has been explicitly disabled
+ stream.capabilities.remove(Capabilities::SSL);
+ }
+
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match options.ssl_mode {
- MySqlSslMode::Disabled => {}
+ MySqlSslMode::Disabled => return Ok(stream.boxed_socket()),
MySqlSslMode::Preferred => {
- // try upgrade, but its okay if we fail
- upgrade(stream, options).await?;
+ if !tls::available() {
+ // Client doesn't support TLS
+ log::debug!("not performing TLS upgrade: TLS support not compiled in");
+ return Ok(stream.boxed_socket());
+ }
+
+ if !server_supports_tls {
+ // Server doesn't support TLS
+ log::debug!("not performing TLS upgrade: unsupported by server");
+ return Ok(stream.boxed_socket());
+ }
}
MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => {
- if !upgrade(stream, options).await? {
+ tls::error_if_unavailable()?;
+
+ if !server_supports_tls {
// upgrade failed, die
return Err(Error::Tls("server does not support TLS".into()));
}
}
}
- Ok(())
-}
-
-async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Result {
- if !stream.capabilities.contains(Capabilities::SSL) {
- // server does not support TLS
- return Ok(false);
- }
+ let tls_config = TlsConfig {
+ accept_invalid_certs: !matches!(
+ options.ssl_mode,
+ MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
+ ),
+ accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
+ hostname: &options.host,
+ root_cert_path: options.ssl_ca.as_ref(),
+ };
+ // Request TLS upgrade
stream.write_packet(SslRequest {
max_packet_size: super::MAX_PACKET_SIZE,
collation: stream.collation as u8,
@@ -41,20 +74,34 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res
stream.flush().await?;
- let accept_invalid_certs = !matches!(
- options.ssl_mode,
- MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
- );
- let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity);
-
- stream
- .upgrade(
- &options.host,
- accept_invalid_certs,
- accept_invalid_host_names,
- options.ssl_ca.as_ref(),
- )
- .await?;
-
- Ok(true)
+ tls::handshake(
+ stream.socket.into_inner(),
+ tls_config,
+ MapStream {
+ server_version: stream.server_version,
+ capabilities: stream.capabilities,
+ sequence_id: stream.sequence_id,
+ waiting: stream.waiting,
+ charset: stream.charset,
+ collation: stream.collation,
+ },
+ )
+ .await
+}
+
+impl WithSocket for MapStream {
+ type Output = MySqlStream;
+
+ fn with_socket(self, socket: S) -> Self::Output {
+ MySqlStream {
+ socket: BufferedSocket::new(Box::new(socket)),
+ server_version: self.server_version,
+ capabilities: self.capabilities,
+ sequence_id: self.sequence_id,
+ waiting: self.waiting,
+ charset: self.charset,
+ collation: self.collation,
+ is_tls: true,
+ }
+ }
}
diff --git a/sqlx-core/src/mysql/options/mod.rs b/sqlx-core/src/mysql/options/mod.rs
index 5d152c3869..d0959579d1 100644
--- a/sqlx-core/src/mysql/options/mod.rs
+++ b/sqlx-core/src/mysql/options/mod.rs
@@ -4,7 +4,7 @@ mod connect;
mod parse;
mod ssl_mode;
-use crate::{connection::LogSettings, net::CertificateInput};
+use crate::{connection::LogSettings, net::tls::CertificateInput};
pub use ssl_mode::MySqlSslMode;
/// Options and flags which can be used to configure a MySQL connection.
@@ -35,8 +35,8 @@ pub use ssl_mode::MySqlSslMode;
/// # use sqlx_core::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode};
/// #
/// # fn main() {
-/// # #[cfg(feature = "_rt-async-std")]
-/// # sqlx_rt::async_std::task::block_on::<_, Result<(), Error>>(async move {
+/// # #[cfg(feature = "_rt")]
+/// # sqlx::__rt::test_block_on(async move {
/// // URL connection string
/// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?;
///
@@ -47,7 +47,7 @@ pub use ssl_mode::MySqlSslMode;
/// .password("password")
/// .database("db")
/// .connect().await?;
-/// # Ok(())
+/// # Result::<(), Error>::Ok(())
/// # }).unwrap();
/// # }
/// ```
diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs
index 429c5f6c44..3c75f32c92 100644
--- a/sqlx-core/src/net/mod.rs
+++ b/sqlx-core/src/net/mod.rs
@@ -1,17 +1,4 @@
mod socket;
-mod tls;
+pub mod tls;
-pub use socket::Socket;
-pub use tls::{CertificateInput, MaybeTlsStream};
-
-#[cfg(feature = "_rt-async-std")]
-type PollReadBuf<'a> = [u8];
-
-#[cfg(feature = "_rt-tokio")]
-type PollReadBuf<'a> = sqlx_rt::ReadBuf<'a>;
-
-#[cfg(feature = "_rt-async-std")]
-type PollReadOut = usize;
-
-#[cfg(feature = "_rt-tokio")]
-type PollReadOut = ();
+pub use socket::{connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket};
diff --git a/sqlx-core/src/net/socket.rs b/sqlx-core/src/net/socket.rs
deleted file mode 100644
index 622a1a22ce..0000000000
--- a/sqlx-core/src/net/socket.rs
+++ /dev/null
@@ -1,134 +0,0 @@
-#![allow(dead_code)]
-
-use std::io;
-use std::net::SocketAddr;
-use std::path::Path;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-use sqlx_rt::{AsyncRead, AsyncWrite, TcpStream};
-
-#[derive(Debug)]
-pub enum Socket {
- Tcp(TcpStream),
-
- #[cfg(unix)]
- Unix(sqlx_rt::UnixStream),
-}
-
-impl Socket {
- pub async fn connect_tcp(host: &str, port: u16) -> io::Result {
- // Trim square brackets from host if it's an IPv6 address as the `url` crate doesn't do that.
- TcpStream::connect((host.trim_matches(|c| c == '[' || c == ']'), port))
- .await
- .map(Socket::Tcp)
- }
-
- #[cfg(unix)]
- pub async fn connect_uds(path: impl AsRef) -> io::Result {
- sqlx_rt::UnixStream::connect(path.as_ref())
- .await
- .map(Socket::Unix)
- }
-
- pub fn local_addr(&self) -> Option {
- match self {
- Self::Tcp(tcp) => tcp.local_addr().ok(),
- #[cfg(unix)]
- Self::Unix(_) => None,
- }
- }
-
- #[cfg(not(unix))]
- pub async fn connect_uds(_: impl AsRef) -> io::Result {
- Err(io::Error::new(
- io::ErrorKind::Other,
- "Unix domain sockets are not supported outside Unix platforms.",
- ))
- }
-
- pub async fn shutdown(&mut self) -> io::Result<()> {
- #[cfg(feature = "_rt-async-std")]
- {
- use std::net::Shutdown;
-
- match self {
- Socket::Tcp(s) => s.shutdown(Shutdown::Both),
-
- #[cfg(unix)]
- Socket::Unix(s) => s.shutdown(Shutdown::Both),
- }
- }
-
- #[cfg(feature = "_rt-tokio")]
- {
- use sqlx_rt::AsyncWriteExt;
-
- match self {
- Socket::Tcp(s) => s.shutdown().await,
-
- #[cfg(unix)]
- Socket::Unix(s) => s.shutdown().await,
- }
- }
- }
-}
-
-impl AsyncRead for Socket {
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut super::PollReadBuf<'_>,
- ) -> Poll> {
- match &mut *self {
- Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf),
-
- #[cfg(unix)]
- Socket::Unix(s) => Pin::new(s).poll_read(cx, buf),
- }
- }
-}
-
-impl AsyncWrite for Socket {
- fn poll_write(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll> {
- match &mut *self {
- Socket::Tcp(s) => Pin::new(s).poll_write(cx, buf),
-
- #[cfg(unix)]
- Socket::Unix(s) => Pin::new(s).poll_write(cx, buf),
- }
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- Socket::Tcp(s) => Pin::new(s).poll_flush(cx),
-
- #[cfg(unix)]
- Socket::Unix(s) => Pin::new(s).poll_flush(cx),
- }
- }
-
- #[cfg(feature = "_rt-tokio")]
- fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- Socket::Tcp(s) => Pin::new(s).poll_shutdown(cx),
-
- #[cfg(unix)]
- Socket::Unix(s) => Pin::new(s).poll_shutdown(cx),
- }
- }
-
- #[cfg(feature = "_rt-async-std")]
- fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- Socket::Tcp(s) => Pin::new(s).poll_close(cx),
-
- #[cfg(unix)]
- Socket::Unix(s) => Pin::new(s).poll_close(cx),
- }
- }
-}
diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs
new file mode 100644
index 0000000000..dc05c87863
--- /dev/null
+++ b/sqlx-core/src/net/socket/buffered.rs
@@ -0,0 +1,234 @@
+use crate::net::Socket;
+use bytes::BytesMut;
+use std::io;
+
+use crate::error::Error;
+
+use crate::io::{Decode, Encode};
+
+// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
+const DEFAULT_BUF_SIZE: usize = 8192;
+
+pub struct BufferedSocket {
+ socket: S,
+ write_buf: WriteBuffer,
+ read_buf: ReadBuffer,
+}
+
+pub struct WriteBuffer {
+ buf: Vec,
+ bytes_written: usize,
+ bytes_flushed: usize,
+}
+
+pub struct ReadBuffer {
+ read: BytesMut,
+ available: BytesMut,
+}
+
+impl BufferedSocket {
+ pub fn new(socket: S) -> Self
+ where
+ S: Sized,
+ {
+ BufferedSocket {
+ socket,
+ write_buf: WriteBuffer {
+ buf: Vec::with_capacity(DEFAULT_BUF_SIZE),
+ bytes_written: 0,
+ bytes_flushed: 0,
+ },
+ read_buf: ReadBuffer {
+ read: BytesMut::new(),
+ available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
+ },
+ }
+ }
+
+ pub async fn read_buffered(&mut self, len: usize) -> io::Result {
+ while self.read_buf.read.len() < len {
+ self.read_buf.reserve(len);
+
+ let read = self.socket.read(&mut self.read_buf.available).await?;
+
+ if read == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ format!(
+ "expected to read {} bytes, got {} bytes at EOF",
+ len,
+ self.read_buf.read.len()
+ ),
+ ));
+ }
+
+ self.read_buf.advance(read);
+ }
+
+ Ok(self.read_buf.drain(len))
+ }
+
+ pub fn write_buffer(&self) -> &WriteBuffer {
+ &self.write_buf
+ }
+
+ pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer {
+ &mut self.write_buf
+ }
+
+ pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result
+ where
+ T: Decode<'de, ()>,
+ {
+ self.read_with(byte_len, ()).await
+ }
+
+ pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result
+ where
+ T: Decode<'de, C>,
+ {
+ T::decode_with(self.read_buffered(byte_len).await?.freeze(), context)
+ }
+
+ pub fn write<'en, T>(&mut self, value: T)
+ where
+ T: Encode<'en, ()>,
+ {
+ self.write_with(value, ())
+ }
+
+ pub fn write_with<'en, T, C>(&mut self, value: T, context: C)
+ where
+ T: Encode<'en, C>,
+ {
+ value.encode_with(self.write_buf.buf_mut(), context);
+ self.write_buf.bytes_written = self.write_buf.buf.len();
+ self.write_buf.sanity_check();
+ }
+
+ pub async fn flush(&mut self) -> io::Result<()> {
+ while !self.write_buf.is_empty() {
+ let written = self.socket.write(self.write_buf.get()).await?;
+ self.write_buf.consume(written);
+ self.write_buf.sanity_check();
+ }
+
+ self.socket.flush().await?;
+
+ Ok(())
+ }
+
+ pub async fn shutdown(&mut self) -> io::Result<()> {
+ self.flush().await?;
+ self.socket.shutdown().await
+ }
+
+ pub fn into_inner(self) -> S {
+ self.socket
+ }
+
+ pub fn boxed(self) -> BufferedSocket> {
+ BufferedSocket {
+ socket: Box::new(self.socket),
+ write_buf: self.write_buf,
+ read_buf: self.read_buf,
+ }
+ }
+}
+
+impl WriteBuffer {
+ fn sanity_check(&self) {
+ assert_ne!(self.buf.capacity(), 0);
+ assert!(self.bytes_written <= self.buf.len());
+ assert!(self.bytes_flushed <= self.bytes_written);
+ }
+
+ pub fn buf_mut(&mut self) -> &mut Vec {
+ self.buf.truncate(self.bytes_written);
+ self.sanity_check();
+ &mut self.buf
+ }
+
+ pub fn init_remaining_mut(&mut self) -> &mut [u8] {
+ self.buf.resize(self.buf.capacity(), 0);
+ self.sanity_check();
+ &mut self.buf[self.bytes_written..]
+ }
+
+ pub fn put_slice(&mut self, slice: &[u8]) {
+ // If we already have an initialized area that can fit the slice,
+ // don't change `self.buf.len()`
+ if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) {
+ dest.copy_from_slice(slice);
+ } else {
+ self.buf.truncate(self.bytes_written);
+ self.buf.extend_from_slice(slice);
+ }
+
+ self.sanity_check();
+ }
+
+ pub fn advance(&mut self, amt: usize) {
+ let new_bytes_written = self
+ .bytes_written
+ .checked_add(amt)
+ .expect("self.bytes_written + amt overflowed");
+
+ assert!(new_bytes_written <= self.buf.len());
+
+ self.bytes_written = new_bytes_written;
+
+ self.sanity_check();
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.bytes_flushed >= self.bytes_written
+ }
+
+ pub fn is_full(&self) -> bool {
+ self.bytes_written == self.buf.len()
+ }
+
+ pub fn get(&self) -> &[u8] {
+ &self.buf[self.bytes_flushed..self.bytes_written]
+ }
+
+ pub fn get_mut(&mut self) -> &mut [u8] {
+ &mut self.buf[self.bytes_flushed..self.bytes_written]
+ }
+
+ fn consume(&mut self, amt: usize) {
+ let new_bytes_flushed = self
+ .bytes_flushed
+ .checked_add(amt)
+ .expect("self.bytes_flushed + amt overflowed");
+
+ assert!(new_bytes_flushed <= self.bytes_written);
+
+ self.bytes_flushed = new_bytes_flushed;
+
+ if self.bytes_flushed == self.bytes_written {
+ // Reset cursors to zero if we've consumed the whole buffer
+ self.bytes_flushed = 0;
+ self.bytes_written = 0;
+ }
+
+ self.sanity_check();
+ }
+}
+
+impl ReadBuffer {
+ fn reserve(&mut self, amt: usize) {
+ if let Some(additional) = amt.checked_sub(self.available.capacity()) {
+ self.available.reserve(additional);
+ }
+ }
+
+ fn advance(&mut self, amt: usize) {
+ self.read.unsplit(self.available.split_to(amt));
+ }
+
+ fn drain(&mut self, amt: usize) -> BytesMut {
+ self.read.split_to(amt)
+ }
+}
diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs
new file mode 100644
index 0000000000..cd7f24d780
--- /dev/null
+++ b/sqlx-core/src/net/socket/mod.rs
@@ -0,0 +1,259 @@
+use std::future::Future;
+use std::io;
+use std::path::Path;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+use bytes::BufMut;
+use futures_core::ready;
+
+pub use buffered::{BufferedSocket, WriteBuffer};
+
+use crate::io::ReadBuf;
+
+mod buffered;
+
+pub trait Socket: Send + Sync + Unpin + 'static {
+ fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result;
+
+ fn try_write(&mut self, buf: &[u8]) -> io::Result;
+
+ fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll>;
+
+ fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll>;
+
+ fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll> {
+ // `flush()` is a no-op for TCP/UDS
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll>;
+
+ fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
+ where
+ Self: Sized,
+ {
+ Read { socket: self, buf }
+ }
+
+ fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
+ where
+ Self: Sized,
+ {
+ Write { socket: self, buf }
+ }
+
+ fn flush(&mut self) -> Flush<'_, Self>
+ where
+ Self: Sized,
+ {
+ Flush { socket: self }
+ }
+
+ fn shutdown(&mut self) -> Shutdown<'_, Self>
+ where
+ Self: Sized,
+ {
+ Shutdown { socket: self }
+ }
+}
+
+pub struct Read<'a, S: ?Sized, B> {
+ socket: &'a mut S,
+ buf: &'a mut B,
+}
+
+impl<'a, S: ?Sized, B> Future for Read<'a, S, B>
+where
+ S: Socket,
+ B: ReadBuf,
+{
+ type Output = io::Result;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ let this = &mut *self;
+
+ while this.buf.has_remaining_mut() {
+ match this.socket.try_read(&mut *this.buf) {
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ ready!(this.socket.poll_read_ready(cx))?;
+ }
+ ready => return Poll::Ready(ready),
+ }
+ }
+
+ Poll::Ready(Ok(0))
+ }
+}
+
+pub struct Write<'a, S: ?Sized> {
+ socket: &'a mut S,
+ buf: &'a [u8],
+}
+
+impl<'a, S: ?Sized> Future for Write<'a, S>
+where
+ S: Socket,
+{
+ type Output = io::Result;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ let this = &mut *self;
+
+ while !this.buf.is_empty() {
+ match this.socket.try_write(&mut this.buf) {
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ ready!(this.socket.poll_write_ready(cx))?;
+ }
+ ready => return Poll::Ready(ready),
+ }
+ }
+
+ Poll::Ready(Ok(0))
+ }
+}
+
+pub struct Flush<'a, S: ?Sized> {
+ socket: &'a mut S,
+}
+
+impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> {
+ type Output = io::Result<()>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ self.socket.poll_flush(cx)
+ }
+}
+
+pub struct Shutdown<'a, S: ?Sized> {
+ socket: &'a mut S,
+}
+
+impl<'a, S: ?Sized> Future for Shutdown<'a, S>
+where
+ S: Socket,
+{
+ type Output = io::Result<()>;
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll {
+ self.socket.poll_shutdown(cx)
+ }
+}
+
+pub trait WithSocket {
+ type Output;
+
+ fn with_socket(self, socket: S) -> Self::Output;
+}
+
+pub struct SocketIntoBox;
+
+impl WithSocket for SocketIntoBox {
+ type Output = Box;
+
+ fn with_socket(self, socket: S) -> Self::Output {
+ Box::new(socket)
+ }
+}
+
+impl Socket for Box {
+ fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result {
+ (**self).try_read(buf)
+ }
+
+ fn try_write(&mut self, buf: &[u8]) -> io::Result {
+ (**self).try_write(buf)
+ }
+
+ fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ (**self).poll_read_ready(cx)
+ }
+
+ fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ (**self).poll_write_ready(cx)
+ }
+
+ fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> {
+ (**self).poll_shutdown(cx)
+ }
+}
+
+pub async fn connect_tcp(
+ host: &str,
+ port: u16,
+ with_socket: Ws,
+) -> crate::Result {
+ // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those.
+ let host = host.trim_matches(&['[', ']'][..]);
+
+ #[cfg(feature = "_rt-tokio")]
+ if crate::rt::rt_tokio::available() {
+ use tokio::net::TcpStream;
+
+ let stream = TcpStream::connect((host, port)).await?;
+
+ return Ok(with_socket.with_socket(stream));
+ }
+
+ #[cfg(feature = "_rt-async-std")]
+ {
+ use async_io::Async;
+ use async_std::net::ToSocketAddrs;
+ use std::net::TcpStream;
+
+ let socket_addr = (host, port)
+ .to_socket_addrs()
+ .await?
+ .next()
+ .expect("BUG: to_socket_addrs() should have returned at least one result");
+
+ let stream = Async::::connect(socket_addr).await?;
+
+ return Ok(with_socket.with_socket(stream));
+ }
+
+ #[cfg(not(feature = "_rt-async-std"))]
+ {
+ crate::rt::missing_rt((host, port, with_socket))
+ }
+}
+
+/// Connect a Unix Domain Socket at the given path.
+///
+/// Returns an error if Unix Domain Sockets are not supported on this platform.
+pub async fn connect_uds, Ws: WithSocket>(
+ path: P,
+ with_socket: Ws,
+) -> crate::Result {
+ if cfg!(not(unix)) {
+ return Err(io::Error::new(
+ io::ErrorKind::Unsupported,
+ "Unix domain sockets are not supported on this platform",
+ )
+ .into());
+ }
+
+ #[cfg(all(unix, feature = "_rt-tokio"))]
+ if crate::rt::rt_tokio::available() {
+ use tokio::net::UnixStream;
+
+ let stream = UnixStream::connect(path).await?;
+
+ return Ok(with_socket.with_socket(stream));
+ }
+
+ #[cfg(all(unix, feature = "_rt-async-std"))]
+ {
+ use async_io::Async;
+ use std::os::unix::net::UnixStream;
+
+ let stream = Async::::connect(path).await?;
+
+ return Ok(with_socket.with_socket(stream));
+ }
+
+ #[cfg(not(feature = "_rt-async-std"))]
+ {
+ crate::rt::missing_rt((path, with_socket))
+ }
+}
diff --git a/sqlx-core/src/net/tls/mod.rs b/sqlx-core/src/net/tls/mod.rs
index 85e5dda7c1..3fae0ecaed 100644
--- a/sqlx-core/src/net/tls/mod.rs
+++ b/sqlx-core/src/net/tls/mod.rs
@@ -1,15 +1,18 @@
#![allow(dead_code)]
-use std::io;
-use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
use crate::error::Error;
-use std::mem::replace;
+use crate::net::socket::WithSocket;
+use crate::net::Socket;
+
+#[cfg(feature = "_tls-rustls")]
+mod tls_rustls;
+
+#[cfg(feature = "_tls-native-tls")]
+mod tls_native_tls;
+
+mod util;
/// X.509 Certificate input, either a file path or a PEM encoded inline certificate(s).
#[derive(Clone, Debug)]
@@ -36,7 +39,7 @@ impl From for CertificateInput {
impl CertificateInput {
async fn data(&self) -> Result, std::io::Error> {
- use sqlx_rt::fs;
+ use crate::fs;
match self {
CertificateInput::Inline(v) => Ok(v.clone()),
CertificateInput::File(path) => fs::read(path).await,
@@ -53,210 +56,46 @@ impl std::fmt::Display for CertificateInput {
}
}
-#[cfg(feature = "_tls-rustls")]
-mod rustls;
-
-pub enum MaybeTlsStream
-where
- S: AsyncRead + AsyncWrite + Unpin,
-{
- Raw(S),
- Tls(TlsStream),
- Upgrading,
+pub struct TlsConfig<'a> {
+ pub accept_invalid_certs: bool,
+ pub accept_invalid_hostnames: bool,
+ pub hostname: &'a str,
+ pub root_cert_path: Option<&'a CertificateInput>,
}
-impl MaybeTlsStream
+pub async fn handshake(
+ socket: S,
+ config: TlsConfig<'_>,
+ with_socket: Ws,
+) -> crate::Result
where
- S: AsyncRead + AsyncWrite + Unpin,
+ S: Socket,
+ Ws: WithSocket,
{
- #[inline]
- pub fn is_tls(&self) -> bool {
- matches!(self, Self::Tls(_))
- }
-
- pub async fn upgrade(
- &mut self,
- host: &str,
- accept_invalid_certs: bool,
- accept_invalid_hostnames: bool,
- root_cert_path: Option<&CertificateInput>,
- ) -> Result<(), Error> {
- let connector = configure_tls_connector(
- accept_invalid_certs,
- accept_invalid_hostnames,
- root_cert_path,
- )
- .await?;
-
- let stream = match replace(self, MaybeTlsStream::Upgrading) {
- MaybeTlsStream::Raw(stream) => stream,
-
- MaybeTlsStream::Tls(_) => {
- // ignore upgrade, we are already a TLS connection
- return Ok(());
- }
-
- MaybeTlsStream::Upgrading => {
- // we previously failed to upgrade and now hold no connection
- // this should only happen from an internal misuse of this method
- return Err(Error::Io(io::ErrorKind::ConnectionAborted.into()));
- }
- };
-
- #[cfg(feature = "_tls-rustls")]
- let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?;
-
- *self = MaybeTlsStream::Tls(connector.connect(host, stream).await?);
-
- Ok(())
- }
-}
-
-#[cfg(feature = "_tls-native-tls")]
-async fn configure_tls_connector(
- accept_invalid_certs: bool,
- accept_invalid_hostnames: bool,
- root_cert_path: Option<&CertificateInput>,
-) -> Result {
- use sqlx_rt::native_tls::{Certificate, TlsConnector};
-
- let mut builder = TlsConnector::builder();
- builder
- .danger_accept_invalid_certs(accept_invalid_certs)
- .danger_accept_invalid_hostnames(accept_invalid_hostnames);
-
- if !accept_invalid_certs {
- if let Some(ca) = root_cert_path {
- let data = ca.data().await?;
- let cert = Certificate::from_pem(&data)?;
-
- builder.add_root_certificate(cert);
- }
- }
+ #[cfg(feature = "_tls-native-tls")]
+ return Ok(with_socket.with_socket(tls_native_tls::handshake(socket, config).await?));
- #[cfg(not(feature = "_rt-async-std"))]
- let connector = builder.build()?.into();
+ #[cfg(feature = "_tls-rustls")]
+ return Ok(with_socket.with_socket(tls_rustls::handshake(socket, config).await?));
- #[cfg(feature = "_rt-async-std")]
- let connector = builder.into();
-
- Ok(connector)
-}
-
-#[cfg(feature = "_tls-rustls")]
-use self::rustls::configure_tls_connector;
-
-impl AsyncRead for MaybeTlsStream
-where
- S: Unpin + AsyncWrite + AsyncRead,
-{
- fn poll_read(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &mut super::PollReadBuf<'_>,
- ) -> Poll> {
- match &mut *self {
- MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
- MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
-
- MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
- }
+ #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))]
+ {
+ drop((socket, config, with_socket));
+ panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled")
}
}
-impl AsyncWrite for MaybeTlsStream
-where
- S: Unpin + AsyncWrite + AsyncRead,
-{
- fn poll_write(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- buf: &[u8],
- ) -> Poll> {
- match &mut *self {
- MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
- MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
-
- MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
- }
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
- MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
-
- MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
- }
- }
-
- #[cfg(feature = "_rt-tokio")]
- fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
- MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
-
- MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
- }
- }
-
- #[cfg(feature = "_rt-async-std")]
- fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
- match &mut *self {
- MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
- MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
-
- MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
- }
- }
+pub fn available() -> bool {
+ cfg!(any(feature = "_tls-native-tls", feature = "_tls-rustls"))
}
-impl Deref for MaybeTlsStream
-where
- S: Unpin + AsyncWrite + AsyncRead,
-{
- type Target = S;
-
- fn deref(&self) -> &Self::Target {
- match self {
- MaybeTlsStream::Raw(s) => s,
-
- #[cfg(feature = "_tls-rustls")]
- MaybeTlsStream::Tls(s) => s.get_ref().0,
-
- #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
- MaybeTlsStream::Tls(s) => s.get_ref(),
-
- #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
- MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(),
-
- MaybeTlsStream::Upgrading => {
- panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
- }
- }
+pub fn error_if_unavailable() -> crate::Result<()> {
+ if !available() {
+ return Err(Error::tls(
+ "TLS upgrade required by connect options \
+ but SQLx was built without TLS support enabled",
+ ));
}
-}
-impl DerefMut for MaybeTlsStream
-where
- S: Unpin + AsyncWrite + AsyncRead,
-{
- fn deref_mut(&mut self) -> &mut Self::Target {
- match self {
- MaybeTlsStream::Raw(s) => s,
-
- #[cfg(feature = "_tls-rustls")]
- MaybeTlsStream::Tls(s) => s.get_mut().0,
-
- #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))]
- MaybeTlsStream::Tls(s) => s.get_mut(),
-
- #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))]
- MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(),
-
- MaybeTlsStream::Upgrading => {
- panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted))
- }
- }
- }
+ Ok(())
}
diff --git a/sqlx-core/src/net/tls/rustls.rs b/sqlx-core/src/net/tls/rustls.rs
deleted file mode 100644
index 2ad958b0d2..0000000000
--- a/sqlx-core/src/net/tls/rustls.rs
+++ /dev/null
@@ -1,108 +0,0 @@
-use crate::net::CertificateInput;
-use rustls::{
- client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier},
- ClientConfig, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName,
-};
-use std::io::Cursor;
-use std::sync::Arc;
-use std::time::SystemTime;
-
-use crate::error::Error;
-
-pub async fn configure_tls_connector(
- accept_invalid_certs: bool,
- accept_invalid_hostnames: bool,
- root_cert_path: Option<&CertificateInput>,
-) -> Result {
- let config = ClientConfig::builder().with_safe_defaults();
-
- let config = if accept_invalid_certs {
- config
- .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
- .with_no_client_auth()
- } else {
- let mut cert_store = RootCertStore::empty();
- cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
- OwnedTrustAnchor::from_subject_spki_name_constraints(
- ta.subject,
- ta.spki,
- ta.name_constraints,
- )
- }));
-
- if let Some(ca) = root_cert_path {
- let data = ca.data().await?;
- let mut cursor = Cursor::new(data);
-
- for cert in rustls_pemfile::certs(&mut cursor)
- .map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
- {
- cert_store
- .add(&rustls::Certificate(cert))
- .map_err(|err| Error::Tls(err.into()))?;
- }
- }
-
- if accept_invalid_hostnames {
- let verifier = WebPkiVerifier::new(cert_store, None);
-
- config
- .with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
- .with_no_client_auth()
- } else {
- config
- .with_root_certificates(cert_store)
- .with_no_client_auth()
- }
- };
-
- Ok(Arc::new(config).into())
-}
-
-struct DummyTlsVerifier;
-
-impl ServerCertVerifier for DummyTlsVerifier {
- fn verify_server_cert(
- &self,
- _end_entity: &rustls::Certificate,
- _intermediates: &[rustls::Certificate],
- _server_name: &ServerName,
- _scts: &mut dyn Iterator- ,
- _ocsp_response: &[u8],
- _now: SystemTime,
- ) -> Result {
- Ok(ServerCertVerified::assertion())
- }
-}
-
-pub struct NoHostnameTlsVerifier {
- verifier: WebPkiVerifier,
-}
-
-impl ServerCertVerifier for NoHostnameTlsVerifier {
- fn verify_server_cert(
- &self,
- end_entity: &rustls::Certificate,
- intermediates: &[rustls::Certificate],
- server_name: &ServerName,
- scts: &mut dyn Iterator
- ,
- ocsp_response: &[u8],
- now: SystemTime,
- ) -> Result {
- match self.verifier.verify_server_cert(
- end_entity,
- intermediates,
- server_name,
- scts,
- ocsp_response,
- now,
- ) {
- Err(TlsError::InvalidCertificateData(reason))
- if reason.contains("CertNotValidForName") =>
- {
- Ok(ServerCertVerified::assertion())
- }
- res => res,
- }
- }
-}
diff --git a/sqlx-core/src/net/tls/tls_native_tls.rs b/sqlx-core/src/net/tls/tls_native_tls.rs
new file mode 100644
index 0000000000..5405bac3c2
--- /dev/null
+++ b/sqlx-core/src/net/tls/tls_native_tls.rs
@@ -0,0 +1,82 @@
+use std::io::{self, Read, Write};
+
+use crate::io::ReadBuf;
+use crate::net::tls::util::StdSocket;
+use crate::net::tls::TlsConfig;
+use crate::net::Socket;
+use crate::Error;
+
+use native_tls::HandshakeError;
+use std::task::{Context, Poll};
+
+pub struct NativeTlsSocket {
+ stream: native_tls::TlsStream>,
+}
+
+impl Socket for NativeTlsSocket
{
+ fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result {
+ self.stream.read(buf.init_mut())
+ }
+
+ fn try_write(&mut self, buf: &[u8]) -> io::Result {
+ self.stream.write(buf)
+ }
+
+ fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.stream.get_mut().poll_ready(cx)
+ }
+
+ fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.stream.get_mut().poll_ready(cx)
+ }
+
+ fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> {
+ match self.stream.shutdown() {
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => self.stream.get_mut().poll_ready(cx),
+ ready => Poll::Ready(ready),
+ }
+ }
+}
+
+/// DEPRECATED: this should never have been public.
+impl From for Error {
+ fn from(e: native_tls::Error) -> Self {
+ Error::Tls(Box::new(e))
+ }
+}
+
+pub async fn handshake(
+ socket: S,
+ config: TlsConfig<'_>,
+) -> crate::Result> {
+ let mut builder = native_tls::TlsConnector::builder();
+
+ builder
+ .danger_accept_invalid_certs(config.accept_invalid_certs)
+ .danger_accept_invalid_hostnames(config.accept_invalid_hostnames);
+
+ if let Some(root_cert_path) = config.root_cert_path {
+ let data = root_cert_path.data().await?;
+ builder.add_root_certificate(native_tls::Certificate::from_pem(&data)?);
+ }
+
+ let connector = builder.build()?;
+
+ let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
+ Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
+ Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
+ Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,
+ };
+
+ loop {
+ mid_handshake.get_mut().ready().await?;
+
+ match mid_handshake.handshake() {
+ Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
+ Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
+ Err(HandshakeError::WouldBlock(mid_handshake_)) => {
+ mid_handshake = mid_handshake_;
+ }
+ }
+ }
+}
diff --git a/sqlx-core/src/net/tls/tls_rustls.rs b/sqlx-core/src/net/tls/tls_rustls.rs
new file mode 100644
index 0000000000..230e03527f
--- /dev/null
+++ b/sqlx-core/src/net/tls/tls_rustls.rs
@@ -0,0 +1,184 @@
+use futures_util::future;
+use std::io;
+use std::io::{Cursor, Read, Write};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::SystemTime;
+
+use rustls::{
+ client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier},
+ ClientConfig, ClientConnection, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName,
+};
+
+use crate::error::Error;
+use crate::io::ReadBuf;
+use crate::net::tls::util::StdSocket;
+use crate::net::tls::TlsConfig;
+use crate::net::Socket;
+
+pub struct RustlsSocket {
+ inner: StdSocket,
+ state: ClientConnection,
+ close_notify_sent: bool,
+}
+
+impl RustlsSocket {
+ fn poll_complete_io(&mut self, cx: &mut Context<'_>) -> Poll> {
+ loop {
+ match self.state.complete_io(&mut self.inner) {
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ futures_util::ready!(self.inner.poll_ready(cx))?;
+ }
+ ready => return Poll::Ready(ready.map(|_| ())),
+ }
+ }
+ }
+
+ async fn complete_io(&mut self) -> io::Result<()> {
+ future::poll_fn(|cx| self.poll_complete_io(cx)).await
+ }
+}
+
+impl Socket for RustlsSocket {
+ fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result {
+ self.state.reader().read(buf.init_mut())
+ }
+
+ fn try_write(&mut self, buf: &[u8]) -> io::Result {
+ match self.state.writer().write(buf) {
+ // Returns a zero-length write when the buffer is full.
+ Ok(0) => Err(io::ErrorKind::WouldBlock.into()),
+ other => return other,
+ }
+ }
+
+ fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.poll_complete_io(cx)
+ }
+
+ fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.poll_complete_io(cx)
+ }
+
+ fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.poll_complete_io(cx)
+ }
+
+ fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> {
+ if !self.close_notify_sent {
+ self.state.send_close_notify();
+ self.close_notify_sent = true;
+ }
+
+ futures_util::ready!(self.poll_complete_io(cx))?;
+ self.inner.socket.poll_shutdown(cx)
+ }
+}
+
+pub async fn handshake(socket: S, tls_config: TlsConfig<'_>) -> Result, Error>
+where
+ S: Socket,
+{
+ let config = ClientConfig::builder().with_safe_defaults();
+
+ let config = if tls_config.accept_invalid_certs {
+ config
+ .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier))
+ .with_no_client_auth()
+ } else {
+ let mut cert_store = RootCertStore::empty();
+ cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
+ OwnedTrustAnchor::from_subject_spki_name_constraints(
+ ta.subject,
+ ta.spki,
+ ta.name_constraints,
+ )
+ }));
+
+ if let Some(ca) = tls_config.root_cert_path {
+ let data = ca.data().await?;
+ let mut cursor = Cursor::new(data);
+
+ for cert in rustls_pemfile::certs(&mut cursor)
+ .map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))?
+ {
+ cert_store
+ .add(&rustls::Certificate(cert))
+ .map_err(|err| Error::Tls(err.into()))?;
+ }
+ }
+
+ if tls_config.accept_invalid_hostnames {
+ let verifier = WebPkiVerifier::new(cert_store, None);
+
+ config
+ .with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
+ .with_no_client_auth()
+ } else {
+ config
+ .with_root_certificates(cert_store)
+ .with_no_client_auth()
+ }
+ };
+
+ let host = rustls::ServerName::try_from(tls_config.hostname).map_err(Error::tls)?;
+
+ let mut socket = RustlsSocket {
+ inner: StdSocket::new(socket),
+ state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?,
+ close_notify_sent: false,
+ };
+
+ // Performs the TLS handshake or bails
+ socket.complete_io().await?;
+
+ Ok(socket)
+}
+
+struct DummyTlsVerifier;
+
+impl ServerCertVerifier for DummyTlsVerifier {
+ fn verify_server_cert(
+ &self,
+ _end_entity: &rustls::Certificate,
+ _intermediates: &[rustls::Certificate],
+ _server_name: &ServerName,
+ _scts: &mut dyn Iterator- ,
+ _ocsp_response: &[u8],
+ _now: SystemTime,
+ ) -> Result {
+ Ok(ServerCertVerified::assertion())
+ }
+}
+
+pub struct NoHostnameTlsVerifier {
+ verifier: WebPkiVerifier,
+}
+
+impl ServerCertVerifier for NoHostnameTlsVerifier {
+ fn verify_server_cert(
+ &self,
+ end_entity: &rustls::Certificate,
+ intermediates: &[rustls::Certificate],
+ server_name: &ServerName,
+ scts: &mut dyn Iterator
- ,
+ ocsp_response: &[u8],
+ now: SystemTime,
+ ) -> Result {
+ match self.verifier.verify_server_cert(
+ end_entity,
+ intermediates,
+ server_name,
+ scts,
+ ocsp_response,
+ now,
+ ) {
+ Err(TlsError::InvalidCertificateData(reason))
+ if reason.contains("CertNotValidForName") =>
+ {
+ Ok(ServerCertVerified::assertion())
+ }
+ res => res,
+ }
+ }
+}
diff --git a/sqlx-core/src/net/tls/util.rs b/sqlx-core/src/net/tls/util.rs
new file mode 100644
index 0000000000..02a16ef5e1
--- /dev/null
+++ b/sqlx-core/src/net/tls/util.rs
@@ -0,0 +1,65 @@
+use crate::net::Socket;
+
+use std::io::{self, Read, Write};
+use std::task::{Context, Poll};
+
+use futures_core::ready;
+use futures_util::future;
+
+pub struct StdSocket
{
+ pub socket: S,
+ wants_read: bool,
+ wants_write: bool,
+}
+
+impl StdSocket {
+ pub fn new(socket: S) -> Self {
+ Self {
+ socket,
+ wants_read: false,
+ wants_write: false,
+ }
+ }
+
+ pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ if self.wants_write {
+ ready!(self.socket.poll_write_ready(cx))?;
+ self.wants_write = false;
+ }
+
+ if self.wants_read {
+ ready!(self.socket.poll_read_ready(cx))?;
+ self.wants_read = false;
+ }
+
+ Poll::Ready(Ok(()))
+ }
+
+ pub async fn ready(&mut self) -> io::Result<()> {
+ future::poll_fn(|cx| self.poll_ready(cx)).await
+ }
+}
+
+impl Read for StdSocket {
+ fn read(&mut self, mut buf: &mut [u8]) -> io::Result {
+ self.wants_read = true;
+ let read = self.socket.try_read(&mut buf)?;
+ self.wants_read = false;
+
+ Ok(read)
+ }
+}
+
+impl Write for StdSocket {
+ fn write(&mut self, buf: &[u8]) -> io::Result {
+ self.wants_write = true;
+ let written = self.socket.try_write(buf)?;
+ self.wants_write = false;
+ Ok(written)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ // NOTE: TCP sockets and unix sockets are both no-ops for flushes
+ Ok(())
+ }
+}
diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs
index 9c61547cbe..ade6326282 100644
--- a/sqlx-core/src/pool/connection.rs
+++ b/sqlx-core/src/pool/connection.rs
@@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::{Duration, Instant};
-use futures_intrusive::sync::SemaphoreReleaser;
+use crate::sync::AsyncSemaphoreReleaser;
use crate::connection::Connection;
use crate::database::Database;
@@ -134,13 +134,7 @@ impl Drop for PoolConnection {
fn drop(&mut self) {
// We still need to spawn a task to maintain `min_connections`.
if self.live.is_some() || self.pool.options.min_connections > 0 {
- #[cfg(not(feature = "_rt-async-std"))]
- if let Ok(handle) = sqlx_rt::Handle::try_current() {
- handle.spawn(self.return_to_pool());
- }
-
- #[cfg(feature = "_rt-async-std")]
- sqlx_rt::spawn(self.return_to_pool());
+ crate::rt::spawn(self.return_to_pool());
}
}
}
@@ -288,7 +282,7 @@ impl Floating> {
pub fn from_idle(
idle: Idle,
pool: Arc>,
- permit: SemaphoreReleaser<'_>,
+ permit: AsyncSemaphoreReleaser<'_>,
) -> Self {
Self {
inner: idle,
diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs
index 7bfae7fc78..1d7d4ba647 100644
--- a/sqlx-core/src/pool/inner.rs
+++ b/sqlx-core/src/pool/inner.rs
@@ -6,7 +6,7 @@ use crate::error::Error;
use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions};
use crossbeam_queue::ArrayQueue;
-use futures_intrusive::sync::{Semaphore, SemaphoreReleaser};
+use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};
use std::cmp;
use std::future::Future;
@@ -22,7 +22,7 @@ use std::time::{Duration, Instant};
pub(crate) struct PoolInner {
pub(super) connect_options: ::Options,
pub(super) idle_conns: ArrayQueue>,
- pub(super) semaphore: Semaphore,
+ pub(super) semaphore: AsyncSemaphore,
pub(super) size: AtomicU32,
pub(super) num_idle: AtomicUsize,
is_closed: AtomicBool,
@@ -49,7 +49,7 @@ impl PoolInner {
let pool = Self {
connect_options,
idle_conns: ArrayQueue::new(capacity),
- semaphore: Semaphore::new(options.fair, semaphore_capacity),
+ semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity),
size: AtomicU32::new(0),
num_idle: AtomicUsize::new(0),
is_closed: AtomicBool::new(false),
@@ -86,7 +86,7 @@ impl PoolInner {
self.on_closed.notify(usize::MAX);
async move {
- for permits in 1..=self.options.max_connections as usize {
+ for permits in 1..=self.options.max_connections {
// Close any currently idle connections in the pool.
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
@@ -112,7 +112,7 @@ impl PoolInner {
///
/// If we steal a permit from the parent but *don't* open a connection,
/// it should be returned to the parent.
- async fn acquire_permit<'a>(self: &'a Arc) -> Result, Error> {
+ async fn acquire_permit<'a>(self: &'a Arc) -> Result, Error> {
let parent = self
.parent()
// If we're already at the max size, we shouldn't try to steal from the parent.
@@ -182,8 +182,8 @@ impl PoolInner {
fn pop_idle<'a>(
self: &'a Arc,
- permit: SemaphoreReleaser<'a>,
- ) -> Result>, SemaphoreReleaser<'a>> {
+ permit: AsyncSemaphoreReleaser<'a>,
+ ) -> Result>, AsyncSemaphoreReleaser<'a>> {
if let Some(idle) = self.idle_conns.pop() {
self.num_idle.fetch_sub(1, Ordering::AcqRel);
Ok(Floating::from_idle(idle, (*self).clone(), permit))
@@ -211,8 +211,8 @@ impl PoolInner {
/// Try to atomically increment the pool size for a new connection.
pub(super) fn try_increment_size<'a>(
self: &'a Arc,
- permit: SemaphoreReleaser<'a>,
- ) -> Result, SemaphoreReleaser<'a>> {
+ permit: AsyncSemaphoreReleaser<'a>,
+ ) -> Result, AsyncSemaphoreReleaser<'a>> {
match self
.size
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
@@ -233,7 +233,7 @@ impl PoolInner {
let deadline = Instant::now() + self.options.acquire_timeout;
- sqlx_rt::timeout(
+ crate::rt::timeout(
self.options.acquire_timeout,
async {
loop {
@@ -263,7 +263,7 @@ impl PoolInner {
// If so, we're likely in the current-thread runtime if it's Tokio
// and so we should yield to let any spawned release_to_pool() tasks
// execute.
- sqlx_rt::yield_now().await;
+ crate::rt::yield_now().await;
continue;
}
};
@@ -294,7 +294,7 @@ impl PoolInner {
// result here is `Result, TimeoutError>`
// if this block does not return, sleep for the backoff timeout and try again
- match sqlx_rt::timeout(timeout, self.connect_options.connect()).await {
+ match crate::rt::timeout(timeout, self.connect_options.connect()).await {
// successfully established connection
Ok(Ok(mut raw)) => {
// See comment on `PoolOptions::after_connect`
@@ -338,7 +338,7 @@ impl PoolInner {
// If the connection is refused, wait in exponentially
// increasing steps for the server to come up,
// capped by a factor of the remaining time until the deadline
- sqlx_rt::sleep(backoff).await;
+ crate::rt::sleep(backoff).await;
backoff = cmp::min(backoff * 2, max_backoff);
}
}
@@ -467,7 +467,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) {
(None, None) => {
if pool.options.min_connections > 0 {
- sqlx_rt::spawn(async move {
+ crate::rt::spawn(async move {
pool.min_connections_maintenance(None).await;
});
}
@@ -476,7 +476,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) {
}
};
- sqlx_rt::spawn(async move {
+ crate::rt::spawn(async move {
// Immediately cancel this task if the pool is closed.
let _ = pool
.close_event()
@@ -488,9 +488,9 @@ fn spawn_maintenance_tasks(pool: &Arc>) {
if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
// `async-std` doesn't have a `sleep_until()`
- sqlx_rt::sleep(duration).await;
+ crate::rt::sleep(duration).await;
} else {
- sqlx_rt::yield_now().await;
+ crate::rt::yield_now().await;
}
// Don't run the reaper right away.
@@ -544,7 +544,7 @@ impl DecrementSizeGuard {
}
}
- pub fn from_permit(pool: Arc>, mut permit: SemaphoreReleaser<'_>) -> Self {
+ pub fn from_permit(pool: Arc>, permit: AsyncSemaphoreReleaser<'_>) -> Self {
// here we effectively take ownership of the permit
permit.disarm();
Self::new_permit(pool)
diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs
index cd163c5039..feb2c9c9e4 100644
--- a/sqlx-core/src/postgres/connection/establish.rs
+++ b/sqlx-core/src/postgres/connection/establish.rs
@@ -3,7 +3,7 @@ use crate::HashMap;
use crate::common::StatementCache;
use crate::error::Error;
use crate::io::Decode;
-use crate::postgres::connection::{sasl, stream::PgStream, tls};
+use crate::postgres::connection::{sasl, stream::PgStream};
use crate::postgres::message::{
Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup,
};
@@ -15,10 +15,8 @@ use crate::postgres::{PgConnectOptions, PgConnection};
impl PgConnection {
pub(crate) async fn establish(options: &PgConnectOptions) -> Result {
- let mut stream = PgStream::connect(options).await?;
-
// Upgrade to TLS if we were asked to and the server supports it
- tls::maybe_upgrade(&mut stream, options).await?;
+ let mut stream = PgStream::connect(options).await?;
// To begin a session, a frontend opens a connection to the server
// and sends a startup message.
diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs
index 325b565c3b..3252857414 100644
--- a/sqlx-core/src/postgres/connection/mod.rs
+++ b/sqlx-core/src/postgres/connection/mod.rs
@@ -73,7 +73,7 @@ impl PgConnection {
// will return when the connection is ready for another query
pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> {
- if !self.stream.wbuf.is_empty() {
+ if !self.stream.write_buffer_mut().is_empty() {
self.stream.flush().await?;
}
@@ -203,6 +203,6 @@ impl Connection for PgConnection {
#[doc(hidden)]
fn should_flush(&self) -> bool {
- !self.stream.wbuf.is_empty()
+ !self.stream.write_buffer().is_empty()
}
}
diff --git a/sqlx-core/src/postgres/connection/stream.rs b/sqlx-core/src/postgres/connection/stream.rs
index 59b5289b8e..3e76da2f48 100644
--- a/sqlx-core/src/postgres/connection/stream.rs
+++ b/sqlx-core/src/postgres/connection/stream.rs
@@ -8,8 +8,9 @@ use futures_util::SinkExt;
use log::Level;
use crate::error::Error;
-use crate::io::{BufStream, Decode, Encode};
-use crate::net::{MaybeTlsStream, Socket};
+use crate::io::{Decode, Encode};
+use crate::net::{self, BufferedSocket, Socket};
+use crate::postgres::connection::tls::MaybeUpgradeTls;
use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity};
@@ -23,7 +24,9 @@ use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity};
// is fully prepared to receive queries
pub struct PgStream {
- inner: BufStream>,
+ // A trait object is okay here as the buffering amortizes the overhead of both the dynamic
+ // function call as well as the syscall.
+ inner: BufferedSocket>,
// buffer of unreceived notification messages from `PUBLISH`
// this is set when creating a PgListener and only written to if that listener is
@@ -37,15 +40,15 @@ pub struct PgStream {
impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result {
- let socket = match options.fetch_socket() {
- Some(ref path) => Socket::connect_uds(path).await?,
- None => Socket::connect_tcp(&options.host, options.port).await?,
+ let socket_future = match options.fetch_socket() {
+ Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
+ None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
};
- let inner = BufStream::new(MaybeTlsStream::Raw(socket));
+ let socket = socket_future.await?;
Ok(Self {
- inner,
+ inner: BufferedSocket::new(socket),
notifications: None,
parameter_statuses: BTreeMap::default(),
server_version_num: None,
@@ -57,7 +60,8 @@ impl PgStream {
T: Encode<'en>,
{
self.write(message);
- self.flush().await
+ self.flush().await?;
+ Ok(())
}
// Expect a specific type and format
@@ -171,7 +175,7 @@ impl PgStream {
}
impl Deref for PgStream {
- type Target = BufStream>;
+ type Target = BufferedSocket>;
#[inline]
fn deref(&self) -> &Self::Target {
diff --git a/sqlx-core/src/postgres/connection/tls.rs b/sqlx-core/src/postgres/connection/tls.rs
index 0c780f401a..882b54ec52 100644
--- a/sqlx-core/src/postgres/connection/tls.rs
+++ b/sqlx-core/src/postgres/connection/tls.rs
@@ -1,78 +1,100 @@
-use bytes::Bytes;
+use futures_core::future::BoxFuture;
use crate::error::Error;
-use crate::postgres::connection::stream::PgStream;
+use crate::net::tls::{self, TlsConfig};
+use crate::net::{Socket, SocketIntoBox, WithSocket};
+
use crate::postgres::message::SslRequest;
use crate::postgres::{PgConnectOptions, PgSslMode};
-pub(super) async fn maybe_upgrade(
- stream: &mut PgStream,
+pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions);
+
+impl<'a> WithSocket for MaybeUpgradeTls<'a> {
+ type Output = BoxFuture<'a, crate::Result>>;
+
+ fn with_socket(self, socket: S) -> Self::Output {
+ Box::pin(maybe_upgrade(socket, self.0))
+ }
+}
+
+async fn maybe_upgrade(
+ mut socket: S,
options: &PgConnectOptions,
-) -> Result<(), Error> {
+) -> Result, Error> {
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match options.ssl_mode {
// FIXME: Implement ALLOW
- PgSslMode::Allow | PgSslMode::Disable => {}
+ PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)),
PgSslMode::Prefer => {
+ if !tls::available() {
+ return Ok(Box::new(socket));
+ }
+
// try upgrade, but its okay if we fail
- upgrade(stream, options).await?;
+ if !request_upgrade(&mut socket, options).await? {
+ return Ok(Box::new(socket));
+ }
}
PgSslMode::Require | PgSslMode::VerifyFull | PgSslMode::VerifyCa => {
- if !upgrade(stream, options).await? {
+ tls::error_if_unavailable()?;
+
+ if !request_upgrade(&mut socket, options).await? {
// upgrade failed, die
return Err(Error::Tls("server does not support TLS".into()));
}
}
}
- Ok(())
+ let accept_invalid_certs = !matches!(
+ options.ssl_mode,
+ PgSslMode::VerifyCa | PgSslMode::VerifyFull
+ );
+ let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
+
+ let config = TlsConfig {
+ accept_invalid_certs,
+ accept_invalid_hostnames,
+ hostname: &options.host,
+ root_cert_path: options.ssl_root_cert.as_ref(),
+ };
+
+ tls::handshake(socket, config, SocketIntoBox).await
}
-async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result {
+async fn request_upgrade(
+ socket: &mut impl Socket,
+ _options: &PgConnectOptions,
+) -> Result {
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11
// To initiate an SSL-encrypted connection, the frontend initially sends an
// SSLRequest message rather than a StartupMessage
- stream.send(SslRequest).await?;
+ socket.write(SslRequest::BYTES).await?;
// The server then responds with a single byte containing S or N, indicating that
// it is willing or unwilling to perform SSL, respectively.
- match stream.read::(1).await?[0] {
+ let mut response = [0u8];
+
+ socket.read(&mut &mut response[..]).await?;
+
+ match response[0] {
b'S' => {
// The server is ready and willing to accept an SSL connection
+ Ok(true)
}
b'N' => {
// The server is _unwilling_ to perform SSL
- return Ok(false);
+ Ok(false)
}
- other => {
- return Err(err_protocol!(
- "unexpected response from SSLRequest: 0x{:02x}",
- other
- ));
- }
+ other => Err(err_protocol!(
+ "unexpected response from SSLRequest: 0x{:02x}",
+ other
+ )),
}
-
- let accept_invalid_certs = !matches!(
- options.ssl_mode,
- PgSslMode::VerifyCa | PgSslMode::VerifyFull
- );
- let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
-
- stream
- .upgrade(
- &options.host,
- accept_invalid_certs,
- accept_invalid_hostnames,
- options.ssl_root_cert.as_ref(),
- )
- .await?;
-
- Ok(true)
}
diff --git a/sqlx-core/src/postgres/copy.rs b/sqlx-core/src/postgres/copy.rs
index 0bad775085..5047bddb67 100644
--- a/sqlx-core/src/postgres/copy.rs
+++ b/sqlx-core/src/postgres/copy.rs
@@ -1,16 +1,24 @@
+use std::borrow::Cow;
+use std::ops::{Deref, DerefMut};
+
+use bytes::{BufMut, Bytes};
+use futures_core::stream::BoxStream;
+
use crate::error::{Error, Result};
use crate::ext::async_stream::TryAsyncStream;
+use crate::io::AsyncRead;
use crate::pool::{Pool, PoolConnection};
use crate::postgres::connection::PgConnection;
use crate::postgres::message::{
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
};
use crate::postgres::Postgres;
-use bytes::{BufMut, Bytes};
-use futures_core::stream::BoxStream;
-use smallvec::alloc::borrow::Cow;
-use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
-use std::ops::{Deref, DerefMut};
+
+#[cfg(not(feature = "_rt-tokio"))]
+use futures_util::io::AsyncReadExt;
+
+#[cfg(feature = "_rt-tokio")]
+use tokio::io::AsyncReadExt;
impl PgConnection {
/// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
@@ -172,8 +180,16 @@ impl> PgCopyIn {
///
/// `source` will be read to the end.
///
- /// ### Note
+ /// ### Note: Completion Step Required
/// You must still call either [Self::finish] or [Self::abort] to complete the process.
+ ///
+ /// ### Note: Runtime Features
+ /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
+ /// depending on which runtime feature is used.
+ ///
+ /// The runtime features _used_ to be mutually exclusive, but are no longer.
+ /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
+ /// takes precedent.
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
// this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
struct BufGuard<'s>(&'s mut Vec);
@@ -189,46 +205,34 @@ impl> PgCopyIn {
// flush any existing messages in the buffer and clear it
conn.stream.flush().await?;
- {
- let buf_stream = &mut *conn.stream;
- let stream = &mut buf_stream.stream;
-
- // ensures the buffer isn't left in an inconsistent state
- let mut guard = BufGuard(&mut buf_stream.wbuf);
-
- let buf: &mut Vec = &mut guard.0;
- buf.push(b'd'); // CopyData format code
- buf.resize(5, 0); // reserve space for the length
-
- loop {
- let read = match () {
- // Tokio lets us read into the buffer without zeroing first
- #[cfg(feature = "runtime-tokio")]
- _ if buf.len() != buf.capacity() => {
- // in case we have some data in the buffer, which can occur
- // if the previous write did not fill the buffer
- buf.truncate(5);
- source.read_buf(buf).await?
- }
- _ => {
- // should be a no-op unless len != capacity
- buf.resize(buf.capacity(), 0);
- source.read(&mut buf[5..]).await?
- }
- };
+ loop {
+ let buf = conn.stream.write_buffer_mut();
+
+ // CopyData format code and reserved space for length
+ buf.put_slice(b"d\0\0\0\0");
+
+ let read = match () {
+ // Tokio lets us read into the buffer without zeroing first
+ #[cfg(feature = "_rt-tokio")]
+ _ => source.read_buf(buf.buf_mut()).await?,
+ #[cfg(not(feature = "_rt-tokio"))]
+ _ => source.read(buf.init_remaining_mut()).await?,
+ };
+
+ if read == 0 {
+ // This will end up sending an empty `CopyData` packet but that should be fine.
+ break;
+ }
- if read == 0 {
- break;
- }
+ buf.advance(read);
- let read32 = u32::try_from(read)
- .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
+ // Write the length
+ let read32 = u32::try_from(read)
+ .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
- (&mut buf[1..]).put_u32(read32 + 4);
+ (&mut buf.get_mut()[1..]).put_u32(read32 + 4);
- stream.write_all(&buf[..read + 5]).await?;
- stream.flush().await?;
- }
+ conn.stream.flush().await?;
}
Ok(self)
diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs
index 1432ae6c06..f5ad29eb06 100644
--- a/sqlx-core/src/postgres/listener.rs
+++ b/sqlx-core/src/postgres/listener.rs
@@ -191,8 +191,8 @@ impl PgListener {
/// # use sqlx_core::postgres::PgListener;
/// # use sqlx_core::error::Error;
/// #
- /// # #[cfg(feature = "_rt-async-std")]
- /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move {
+ /// # #[cfg(feature = "_rt")]
+ /// # sqlx::__rt::test_block_on(async move {
/// # let mut listener = PgListener::connect("postgres:// ...").await?;
/// loop {
/// // ask for next notification, re-connecting (transparently) if needed
@@ -200,7 +200,7 @@ impl PgListener {
///
/// // handle notification, do something interesting
/// }
- /// # Ok(())
+ /// # Result::<(), Error>::Ok(())
/// # }).unwrap();
/// ```
pub async fn recv(&mut self) -> Result {
@@ -222,8 +222,8 @@ impl PgListener {
/// # use sqlx_core::postgres::PgListener;
/// # use sqlx_core::error::Error;
/// #
- /// # #[cfg(feature = "_rt-async-std")]
- /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move {
+ /// # #[cfg(feature = "_rt")]
+ /// # sqlx::__rt::test_block_on(async move {
/// # let mut listener = PgListener::connect("postgres:// ...").await?;
/// loop {
/// // start handling notifications, connecting if needed
@@ -233,7 +233,7 @@ impl PgListener {
///
/// // connection lost, do something interesting
/// }
- /// # Ok(())
+ /// # Result::<(), Error>::Ok(())
/// # }).unwrap();
/// ```
pub async fn try_recv(&mut self) -> Result