From 7f0d1af825dd07ddc214e663a0e59244fa665dad Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 4 Aug 2022 13:39:35 -0700 Subject: [PATCH] WIP rt refactors --- .github/workflows/sqlx.yml | 42 +-- Cargo.toml | 53 ++-- README.md | 39 ++- sqlx-bench/Cargo.toml | 5 - sqlx-bench/benches/pg_pool.rs | 12 +- sqlx-bench/benches/sqlite_fetch_all.rs | 2 +- sqlx-core/src/column.rs | 2 + sqlx-core/src/common/mod.rs | 2 +- sqlx-core/src/connection.rs | 6 +- sqlx-core/src/error.rs | 12 +- sqlx-core/src/fs.rs | 96 +++++++ sqlx-core/src/io/buf_stream.rs | 56 ---- sqlx-core/src/io/mod.rs | 14 +- sqlx-core/src/io/read_buf.rs | 35 +++ sqlx-core/src/lib.rs | 17 +- sqlx-core/src/migrate/migrator.rs | 6 +- sqlx-core/src/migrate/source.rs | 17 +- sqlx-core/src/mssql/connection/mod.rs | 22 +- sqlx-core/src/mssql/connection/stream.rs | 36 +-- sqlx-core/src/mysql/connection/auth.rs | 2 +- sqlx-core/src/mysql/connection/establish.rs | 112 +++++--- sqlx-core/src/mysql/connection/mod.rs | 2 +- sqlx-core/src/mysql/connection/stream.rs | 75 ++--- sqlx-core/src/mysql/connection/tls.rs | 111 +++++--- sqlx-core/src/mysql/options/mod.rs | 8 +- sqlx-core/src/net/mod.rs | 17 +- sqlx-core/src/net/socket.rs | 134 --------- sqlx-core/src/net/socket/buffered.rs | 234 ++++++++++++++++ sqlx-core/src/net/socket/mod.rs | 259 ++++++++++++++++++ sqlx-core/src/net/tls/mod.rs | 241 +++------------- sqlx-core/src/net/tls/rustls.rs | 108 -------- sqlx-core/src/net/tls/tls_native_tls.rs | 82 ++++++ sqlx-core/src/net/tls/tls_rustls.rs | 184 +++++++++++++ sqlx-core/src/net/tls/util.rs | 65 +++++ sqlx-core/src/pool/connection.rs | 12 +- sqlx-core/src/pool/inner.rs | 36 +-- .../src/postgres/connection/establish.rs | 6 +- sqlx-core/src/postgres/connection/mod.rs | 4 +- sqlx-core/src/postgres/connection/stream.rs | 24 +- sqlx-core/src/postgres/connection/tls.rs | 94 ++++--- sqlx-core/src/postgres/copy.rs | 88 +++--- sqlx-core/src/postgres/listener.rs | 20 +- sqlx-core/src/postgres/message/ssl_request.rs | 8 +- sqlx-core/src/postgres/options/mod.rs | 12 +- sqlx-core/src/query.rs | 4 +- sqlx-core/src/query_as.rs | 4 +- sqlx-core/src/query_builder.rs | 2 +- sqlx-core/src/query_scalar.rs | 4 +- sqlx-core/src/rt/mod.rs | 168 ++++++++++++ sqlx-core/src/rt/rt_async_std/mod.rs | 1 + sqlx-core/src/rt/rt_async_std/socket.rs | 55 ++++ sqlx-core/src/rt/rt_tokio/mod.rs | 5 + sqlx-core/src/rt/rt_tokio/socket.rs | 55 ++++ sqlx-core/src/sqlite/migrate.rs | 2 +- sqlx-core/src/sqlite/options/mod.rs | 4 +- sqlx-core/src/sqlite/testing/mod.rs | 8 +- sqlx-core/src/statement.rs | 1 + sqlx-core/src/sync.rs | 145 ++++++++++ sqlx-core/src/testing/mod.rs | 7 +- sqlx-macros/src/query/mod.rs | 43 ++- sqlx-rt/Cargo.toml | 2 + sqlx-rt/src/connect.rs | 6 + sqlx-rt/src/lib.rs | 2 + src/lib.md | 62 +++++ src/lib.rs | 13 +- src/macros/test.md | 4 +- tests/mysql/mysql.rs | 4 +- tests/postgres/postgres.rs | 20 +- tests/sqlite/sqlcipher.rs | 9 +- tests/sqlite/sqlite.db | Bin 36864 -> 36864 bytes tests/sqlite/sqlite.rs | 23 +- 71 files changed, 2088 insertions(+), 977 deletions(-) create mode 100644 sqlx-core/src/fs.rs create mode 100644 sqlx-core/src/io/read_buf.rs delete mode 100644 sqlx-core/src/net/socket.rs create mode 100644 sqlx-core/src/net/socket/buffered.rs create mode 100644 sqlx-core/src/net/socket/mod.rs delete mode 100644 sqlx-core/src/net/tls/rustls.rs create mode 100644 sqlx-core/src/net/tls/tls_native_tls.rs create mode 100644 sqlx-core/src/net/tls/tls_rustls.rs create mode 100644 sqlx-core/src/net/tls/util.rs create mode 100644 sqlx-core/src/rt/mod.rs create mode 100644 sqlx-core/src/rt/rt_async_std/mod.rs create mode 100644 sqlx-core/src/rt/rt_async_std/socket.rs create mode 100644 sqlx-core/src/rt/rt_tokio/mod.rs create mode 100644 sqlx-core/src/rt/rt_tokio/socket.rs create mode 100644 sqlx-core/src/sync.rs create mode 100644 sqlx-rt/src/connect.rs create mode 100644 src/lib.md diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 2083eebc1c..7a44ede969 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -33,7 +33,7 @@ jobs: strategy: matrix: runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] steps: - uses: actions/checkout@v2 @@ -53,14 +53,14 @@ jobs: args: > --manifest-path sqlx-core/Cargo.toml --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - uses: actions-rs/cargo@v1 with: command: check args: > --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }},macros + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros test: name: Unit Test @@ -68,7 +68,7 @@ jobs: strategy: matrix: runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] steps: - uses: actions/checkout@v2 @@ -87,7 +87,7 @@ jobs: command: test args: > --manifest-path sqlx-core/Cargo.toml - --features offline,all-databases,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features offline,all-databases,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} cli: name: CLI Binaries @@ -139,7 +139,6 @@ jobs: strategy: matrix: runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -161,7 +160,7 @@ jobs: command: test args: > --no-default-features - --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }} -- --test-threads=1 env: @@ -176,7 +175,7 @@ jobs: matrix: postgres: [14, 10] runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] needs: check steps: - uses: actions/checkout@v2 @@ -199,7 +198,7 @@ jobs: with: command: build args: > - --features postgres,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features postgres,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: | docker-compose -f tests/docker-compose.yml run -d -p 5432:5432 --name postgres_${{ matrix.postgres }} postgres_${{ matrix.postgres }} @@ -210,7 +209,7 @@ jobs: command: test args: > --no-default-features - --features any,postgres,macros,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,postgres,macros,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx # FIXME: needed to disable `ltree` tests in Postgres 9.6 @@ -218,11 +217,12 @@ jobs: RUSTFLAGS: --cfg postgres_${{ matrix.postgres }} - uses: actions-rs/cargo@v1 + if: matrix.tls != 'none' with: command: test args: > --no-default-features - --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx?sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt # FIXME: needed to disable `ltree` tests in Postgres 9.6 @@ -236,7 +236,7 @@ jobs: matrix: mysql: [8, 5_7] runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] needs: check steps: - uses: actions/checkout@v2 @@ -255,7 +255,7 @@ jobs: with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features mysql,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mysql_${{ matrix.mysql }} - run: sleep 60 @@ -265,7 +265,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled @@ -276,7 +276,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -287,7 +287,7 @@ jobs: matrix: mariadb: [10_6, 10_3] runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] needs: check steps: - uses: actions/checkout@v2 @@ -306,7 +306,7 @@ jobs: with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features mysql,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -316,7 +316,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -327,7 +327,7 @@ jobs: matrix: mssql: [2019, 2017] runtime: [async-std, tokio, actix] - tls: [native-tls, rustls] + tls: [native, rustls, none] needs: check steps: - uses: actions/checkout@v2 @@ -346,7 +346,7 @@ jobs: with: command: build args: > - --features mssql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features mssql,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 1433:1433 mssql_${{ matrix.mssql }} - run: sleep 80 # MSSQL takes a "bit" to startup @@ -356,6 +356,6 @@ jobs: command: test args: > --no-default-features - --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} env: DATABASE_URL: mssql://sa:Password123!@localhost/sqlx diff --git a/Cargo.toml b/Cargo.toml index 5b55757f1d..7ac0a51a75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ members = [ ".", "sqlx-core", - "sqlx-rt", "sqlx-macros", "sqlx-test", "sqlx-cli", @@ -45,8 +44,7 @@ default = ["macros", "migrate"] macros = ["sqlx-macros"] migrate = ["sqlx-macros/migrate", "sqlx-core/migrate"] -# [deprecated] TLS is not possible to disable due to it being conditional on multiple features -# Hopefully Cargo can handle this in the future +# [deprecated] enabling TLS requires choosing a specific backend tls = [] # offline building support in `sqlx-macros` @@ -69,36 +67,28 @@ all-types = [ "git2", ] -# previous runtimes, available as features for error messages better than just -# "feature doesn't exist" -runtime-actix = [] -runtime-async-std = [] -runtime-tokio = [] +# Base runtime features without TLS +runtime-actix = ["_rt-tokio", "sqlx-core/runtime-tokio", "sqlx-macros/runtime-tokio"] +runtime-async-std = ["_rt-async-std", "sqlx-core/runtime-async-std", "sqlx-macros/runtime-async-std"] +runtime-tokio = ["_rt-tokio", "sqlx-core/runtime-tokio", "sqlx-macros/runtime-tokio"] -# actual runtimes -runtime-actix-native-tls = ["runtime-tokio-native-tls"] -runtime-async-std-native-tls = [ - "sqlx-core/runtime-async-std-native-tls", - "sqlx-macros/runtime-async-std-native-tls", - "_rt-async-std", -] -runtime-tokio-native-tls = [ - "sqlx-core/runtime-tokio-native-tls", - "sqlx-macros/runtime-tokio-native-tls", - "_rt-tokio", -] +# TLS features +# didn't call this `tls-native-tls` because of the annoying tautology +tls-native = ["sqlx-core/tls-native", "sqlx-macros/tls-native"] +tls-rustls = ["sqlx-core/tls-rustls", "sqlx-macros/tls-rustls"] +# No-op feature used by the workflows to compile without TLS enabled. Not meant for general use. +tls-none = [] + +# Legacy Runtime + TLS features +runtime-actix-native-tls = ["runtime-tokio-native-tls"] runtime-actix-rustls = ["runtime-tokio-rustls"] -runtime-async-std-rustls = [ - "sqlx-core/runtime-async-std-rustls", - "sqlx-macros/runtime-async-std-rustls", - "_rt-async-std", -] -runtime-tokio-rustls = [ - "sqlx-core/runtime-tokio-rustls", - "sqlx-macros/runtime-tokio-rustls", - "_rt-tokio", -] + +runtime-async-std-native-tls = ["runtime-async-std", "tls-native"] +runtime-async-std-rustls = ["runtime-async-std", "tls-rustls"] + +runtime-tokio-native-tls = ["runtime-tokio", "tls-native"] +runtime-tokio-rustls = ["runtime-tokio", "tls-rustls"] # for conditional compilation _rt-async-std = [] @@ -137,7 +127,6 @@ async-std = { version = "1.10.0", features = ["attributes"] } tokio = { version = "1.15.0", features = ["full"] } dotenvy = "0.15.0" trybuild = "1.0.53" -sqlx-rt = { path = "./sqlx-rt" } sqlx-test = { path = "./sqlx-test" } paste = "1.0.6" serde = { version = "1.0.132", features = ["derive"] } @@ -203,7 +192,7 @@ path = "tests/sqlite/derives.rs" required-features = ["sqlite", "macros"] [[test]] -name = "sqlcipher" +name = "sqlite-sqlcipher" path = "tests/sqlite/sqlcipher.rs" required-features = ["sqlite"] diff --git a/README.md b/README.md index 7c7b66a237..0c2499d231 100644 --- a/README.md +++ b/README.md @@ -123,28 +123,57 @@ SQLx is compatible with the [`async-std`], [`tokio`] and [`actix`] runtimes; and ```toml # Cargo.toml [dependencies] +# PICK ONE OF THE FOLLOWING: + +# tokio (no TLS) +sqlx = { version = "0.6", features = [ "runtime-tokio" ] } +# tokio + native-tls +sqlx = { version = "0.6", features = [ "runtime-tokio", "tls-native" ] } # tokio + rustls -sqlx = { version = "0.6", features = [ "runtime-tokio-rustls" ] } +sqlx = { version = "0.6", features = [ "runtime-tokio", "tls-rustls" ] } + +# async-std (no TLS) +sqlx = { version = "0.6", features = [ "runtime-async-std" ] } # async-std + native-tls -sqlx = { version = "0.6", features = [ "runtime-async-std-native-tls" ] } +sqlx = { version = "0.6", features = [ "runtime-async-std", "tls-native" ] } +# async-std + rustls +sqlx = { version = "0.6", features = [ "runtime-async-std", "tls-rustls" ] } ``` -The runtime and TLS backend not being separate feature sets to select is a workaround for a [Cargo issue](https://github.com/rust-lang/cargo/issues/3494). - #### Cargo Feature Flags +For backwards-compatibility reasons, the runtime and TLS features can either be chosen together as a single feature, +or separately. + +For forward-compatibility, you should use the separate runtime and TLS features as the combination features may +be removed in the future. + +- `runtime-async-std`: Use the `async-std` runtime without enabling a TLS backend. + - `runtime-async-std-native-tls`: Use the `async-std` runtime and `native-tls` TLS backend. - `runtime-async-std-rustls`: Use the `async-std` runtime and `rustls` TLS backend. +- `runtime-tokio`: Use the `tokio` runtime without enabling a TLS backend. + - `runtime-tokio-native-tls`: Use the `tokio` runtime and `native-tls` TLS backend. - `runtime-tokio-rustls`: Use the `tokio` runtime and `rustls` TLS backend. +- `runtime-actix`: Use the `actix` runtime without enabling a TLS backend. + - `runtime-actix-native-tls`: Use the `actix` runtime and `native-tls` TLS backend. - `runtime-actix-rustls`: Use the `actix` runtime and `rustls` TLS backend. + - Actix-web is fully compatible with Tokio and so a separate runtime feature is no longer needed. + The above three features exist only for backwards compatibility, and are in fact merely aliases to their + `runtime-tokio` counterparts. + +- `tls-native`: Use the `native-tls` TLS backend (OpenSSL on *nix, SChannel on Windows, Secure Transport on macOS). + +- `tls-rustls`: Use the `rustls` TLS backend (crossplatform backend, only supports TLS 1.2 and 1.3). + - `postgres`: Add support for the Postgres database server. - `mysql`: Add support for the MySQL/MariaDB database server. @@ -177,8 +206,6 @@ sqlx = { version = "0.6", features = [ "runtime-async-std-native-tls" ] } - `json`: Add support for `JSON` and `JSONB` (in postgres) using the `serde_json` crate. -- `tls`: Add support for TLS connections. - - `offline`: Enables building the macros in offline mode when a live database is not available (such as CI). - Requires `sqlx-cli` installed to use. See [sqlx-cli/README.md][readme-offline]. diff --git a/sqlx-bench/Cargo.toml b/sqlx-bench/Cargo.toml index 58504423ee..5fcb35068a 100644 --- a/sqlx-bench/Cargo.toml +++ b/sqlx-bench/Cargo.toml @@ -9,21 +9,17 @@ publish = false runtime-actix-native-tls = ["runtime-tokio-native-tls"] runtime-async-std-native-tls = [ "sqlx/runtime-async-std-native-tls", - "sqlx-rt/runtime-async-std-native-tls", ] runtime-tokio-native-tls = [ "sqlx/runtime-tokio-native-tls", - "sqlx-rt/runtime-tokio-native-tls", ] runtime-actix-rustls = ["runtime-tokio-rustls"] runtime-async-std-rustls = [ "sqlx/runtime-async-std-rustls", - "sqlx-rt/runtime-async-std-rustls", ] runtime-tokio-rustls = [ "sqlx/runtime-tokio-rustls", - "sqlx-rt/runtime-tokio-rustls", ] postgres = ["sqlx/postgres"] @@ -34,7 +30,6 @@ criterion = "0.3.3" dotenvy = "0.15.0" once_cell = "1.4" sqlx = { version = "0.6", path = "../", default-features = false, features = ["macros"] } -sqlx-rt = { version = "0.6", path = "../sqlx-rt", default-features = false } chrono = "0.4.19" diff --git a/sqlx-bench/benches/pg_pool.rs b/sqlx-bench/benches/pg_pool.rs index 0e4f2e78d5..4fdce3b6b6 100644 --- a/sqlx-bench/benches/pg_pool.rs +++ b/sqlx-bench/benches/pg_pool.rs @@ -23,7 +23,7 @@ fn bench_pgpool_acquire(c: &mut Criterion) { } fn do_bench_acquire(b: &mut Bencher, concurrent: u32, fair: bool) { - let pool = sqlx_rt::block_on( + let pool = sqlx::__rt::block_on( PgPoolOptions::new() // we don't want timeouts because we want to see how the pool degrades .acquire_timeout(Duration::from_secs(3600)) @@ -41,8 +41,8 @@ fn do_bench_acquire(b: &mut Bencher, concurrent: u32, fair: bool) { for _ in 0..concurrent { let pool = pool.clone(); - sqlx_rt::enter_runtime(|| { - sqlx_rt::spawn(async move { + sqlx::__rt::enter_runtime(|| { + sqlx::__rt::spawn(async move { while !pool.is_closed() { let conn = match pool.acquire().await { Ok(conn) => conn, @@ -51,7 +51,7 @@ fn do_bench_acquire(b: &mut Bencher, concurrent: u32, fair: bool) { }; // pretend we're using the connection - sqlx_rt::sleep(Duration::from_micros(500)).await; + sqlx::__rt::sleep(Duration::from_micros(500)).await; drop(criterion::black_box(conn)); } }) @@ -59,7 +59,7 @@ fn do_bench_acquire(b: &mut Bencher, concurrent: u32, fair: bool) { } b.iter_custom(|iters| { - sqlx_rt::block_on(async { + sqlx::__rt::block_on(async { // take the start time inside the future to make sure we only count once it's running let start = Instant::now(); for _ in 0..iters { @@ -73,7 +73,7 @@ fn do_bench_acquire(b: &mut Bencher, concurrent: u32, fair: bool) { }) }); - sqlx_rt::block_on(pool.close()); + sqlx::__rt::block_on(pool.close()); } criterion_group!(pg_pool, bench_pgpool_acquire); diff --git a/sqlx-bench/benches/sqlite_fetch_all.rs b/sqlx-bench/benches/sqlite_fetch_all.rs index 690b1ddd9a..ba6f6f71d5 100644 --- a/sqlx-bench/benches/sqlite_fetch_all.rs +++ b/sqlx-bench/benches/sqlite_fetch_all.rs @@ -8,7 +8,7 @@ struct Test { } fn main() -> sqlx::Result<()> { - sqlx_rt::block_on(async { + sqlx::__rt::block_on(async { let mut conn = sqlx::SqliteConnection::connect("sqlite://test.db?mode=rwc").await?; let delete_sql = "DROP TABLE IF EXISTS test"; conn.execute(delete_sql).await?; diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index e670e3b4cd..f09ce57013 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -55,6 +55,7 @@ impl + ?Sized> ColumnIndex for &'_ I { } } +#[macro_export] macro_rules! impl_column_index_for_row { ($R:ident) => { impl crate::column::ColumnIndex<$R> for usize { @@ -71,6 +72,7 @@ macro_rules! impl_column_index_for_row { }; } +#[macro_export] macro_rules! impl_column_index_for_statement { ($S:ident) => { impl crate::column::ColumnIndex<$S<'_>> for usize { diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs index 63ed52815b..200445f539 100644 --- a/sqlx-core/src/common/mod.rs +++ b/sqlx-core/src/common/mod.rs @@ -1,6 +1,6 @@ mod statement_cache; -pub(crate) use statement_cache::StatementCache; +pub use statement_cache::StatementCache; use std::fmt::{Debug, Formatter}; use std::ops::{Deref, DerefMut}; diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 80469ea3e5..8d593b4bc0 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -132,7 +132,7 @@ pub trait Connection: Send { } #[derive(Clone, Debug)] -pub(crate) struct LogSettings { +pub struct LogSettings { pub(crate) statements_level: LevelFilter, pub(crate) slow_statements_level: LevelFilter, pub(crate) slow_statements_duration: Duration, @@ -149,10 +149,10 @@ impl Default for LogSettings { } impl LogSettings { - pub(crate) fn log_statements(&mut self, level: LevelFilter) { + pub fn log_statements(&mut self, level: LevelFilter) { self.statements_level = level; } - pub(crate) fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) { + pub fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) { self.slow_statements_level = level; self.slow_statements_duration = duration; } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index cc8ff523a6..06c80eb495 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -133,6 +133,10 @@ impl Error { pub(crate) fn config(err: impl StdError + Send + Sync + 'static) -> Self { Error::Configuration(err.into()) } + + pub(crate) fn tls(err: impl Into>) -> Self { + Error::Tls(err.into()) + } } pub(crate) fn mismatched_types>(ty: &DB::TypeInfo) -> BoxDynError { @@ -250,14 +254,6 @@ impl From for Error { } } -#[cfg(feature = "_tls-native-tls")] -impl From for Error { - #[inline] - fn from(error: sqlx_rt::native_tls::Error) -> Self { - Error::Tls(Box::new(error)) - } -} - // Format an error message as a `Protocol` error macro_rules! err_protocol { ($expr:expr) => { diff --git a/sqlx-core/src/fs.rs b/sqlx-core/src/fs.rs new file mode 100644 index 0000000000..0993cbeec6 --- /dev/null +++ b/sqlx-core/src/fs.rs @@ -0,0 +1,96 @@ +use std::ffi::OsString; +use std::fs::Metadata; +use std::io; +use std::path::{Path, PathBuf}; + +use crate::rt; + +pub struct ReadDir { + inner: Option, +} + +pub struct DirEntry { + pub path: PathBuf, + pub file_name: OsString, + pub metadata: Metadata, +} + +// Filesystem operations are generally not capable of being non-blocking +// so Tokio and async-std don't bother; they just send the work to a blocking thread pool. +// +// We save on code duplication here by just implementing the same strategy ourselves +// using the runtime's `spawn_blocking()` primitive. + +pub async fn read>(path: P) -> io::Result> { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::read(path)).await +} + +pub async fn read_to_string>(path: P) -> io::Result { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::read_to_string(path)).await +} + +pub async fn create_dir_all>(path: P) -> io::Result<()> { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::create_dir_all(path)).await +} + +pub async fn remove_file>(path: P) -> io::Result<()> { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::remove_file(path)).await +} + +pub async fn remove_dir>(path: P) -> io::Result<()> { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::remove_dir(path)).await +} + +pub async fn remove_dir_all>(path: P) -> io::Result<()> { + let path = PathBuf::from(path.as_ref()); + rt::spawn_blocking(move || std::fs::remove_dir_all(path)).await +} + +pub async fn read_dir(path: PathBuf) -> io::Result { + let read_dir = rt::spawn_blocking(move || std::fs::read_dir(path)).await?; + + Ok(ReadDir { + inner: Some(read_dir), + }) +} + +impl ReadDir { + pub async fn next(&mut self) -> io::Result> { + if let Some(mut read_dir) = self.inner.take() { + let maybe = rt::spawn_blocking(move || { + let entry = read_dir.next().transpose()?; + + entry + .map(|entry| -> io::Result<_> { + Ok(( + read_dir, + DirEntry { + path: entry.path(), + file_name: entry.file_name(), + // We always want the metadata as well so might as well fetch + // it in the same blocking call. + metadata: entry.metadata()?, + }, + )) + }) + .transpose() + }) + .await?; + + match maybe { + Some((read_dir, entry)) => { + self.inner = Some(read_dir); + Ok(Some(entry)) + } + None => Ok(None), + } + } else { + Ok(None) + } + } +} diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 8f376cbfb0..a21cfe288d 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -103,59 +103,3 @@ where &mut self.stream } } - -// Holds a buffer which has been temporarily extended, so that -// we can read into it. Automatically shrinks the buffer back -// down if the read is cancelled. -struct BufTruncator<'a> { - buf: &'a mut BytesMut, - filled_len: usize, -} - -impl<'a> BufTruncator<'a> { - fn new(buf: &'a mut BytesMut) -> Self { - let filled_len = buf.len(); - Self { buf, filled_len } - } - fn reserve(&mut self, space: usize) { - self.buf.resize(self.filled_len + space, 0); - } - async fn read(&mut self, stream: &mut S) -> Result { - let n = stream.read(&mut self.buf[self.filled_len..]).await?; - self.filled_len += n; - Ok(n) - } - fn is_full(&self) -> bool { - self.filled_len >= self.buf.len() - } -} - -impl Drop for BufTruncator<'_> { - fn drop(&mut self) { - self.buf.truncate(self.filled_len); - } -} - -async fn read_raw_into( - stream: &mut S, - buf: &mut BytesMut, - cnt: usize, -) -> Result<(), Error> { - let mut buf = BufTruncator::new(buf); - buf.reserve(cnt); - - while !buf.is_full() { - let n = buf.read(stream).await?; - - if n == 0 { - // a zero read when we had space in the read buffer - // should be treated as an EOF - - // and an unexpected EOF means the server told us to go away - - return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); - } - } - - Ok(()) -} diff --git a/sqlx-core/src/io/mod.rs b/sqlx-core/src/io/mod.rs index f994965400..44fe89fb27 100644 --- a/sqlx-core/src/io/mod.rs +++ b/sqlx-core/src/io/mod.rs @@ -1,12 +1,20 @@ mod buf; mod buf_mut; -mod buf_stream; +// mod buf_stream; mod decode; mod encode; -mod write_and_flush; +mod read_buf; +// mod write_and_flush; pub use buf::BufExt; pub use buf_mut::BufMutExt; -pub use buf_stream::BufStream; +//pub use buf_stream::BufStream; pub use decode::Decode; pub use encode::Encode; +pub use read_buf::ReadBuf; + +#[cfg(not(feature = "_rt-tokio"))] +pub use futures_io::AsyncRead; + +#[cfg(feature = "_rt-tokio")] +pub use tokio::io::AsyncRead; diff --git a/sqlx-core/src/io/read_buf.rs b/sqlx-core/src/io/read_buf.rs new file mode 100644 index 0000000000..c32f37befd --- /dev/null +++ b/sqlx-core/src/io/read_buf.rs @@ -0,0 +1,35 @@ +use bytes::{BufMut, BytesMut}; + +/// An extension for [`BufMut`] for getting a writeable buffer in safe code. +pub trait ReadBuf: BufMut { + /// Get the full capacity of this buffer as a safely initialized slice. + fn init_mut(&mut self) -> &mut [u8]; +} + +impl ReadBuf for &'_ mut [u8] { + #[inline(always)] + fn init_mut(&mut self) -> &mut [u8] { + self + } +} + +impl ReadBuf for BytesMut { + #[inline(always)] + fn init_mut(&mut self) -> &mut [u8] { + // `self.remaining_mut()` returns `usize::MAX - self.len()` + let remaining = self.capacity() - self.len(); + + // I'm hoping for most uses that this operation is elided by the optimizer. + self.put_bytes(0, remaining); + + self + } +} + +#[test] +fn test_read_buf_bytes_mut() { + let mut buf = BytesMut::with_capacity(8); + buf.put_u32(0x12345678); + + assert_eq!(buf.init_mut(), [0x12, 0x34, 0x56, 0x78, 0, 0, 0, 0]); +} diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 83e0203b93..cc653689ba 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -1,5 +1,5 @@ //! Core of SQLx, the rust SQL toolkit. -//! Not intended to be used directly. +//! Not intended to be used directly; semver exempt. #![recursion_limit = "512"] #![warn(future_incompatible, rust_2018_idioms)] #![allow(clippy::needless_doctest_main, clippy::type_complexity)] @@ -57,19 +57,22 @@ pub mod column; #[macro_use] pub mod statement; -mod common; +pub mod common; pub use either::Either; pub mod database; pub mod describe; pub mod executor; pub mod from_row; -mod io; -mod logger; -mod net; +pub mod fs; +pub mod io; +pub mod logger; +pub mod net; pub mod query_as; pub mod query_builder; pub mod query_scalar; pub mod row; +pub mod rt; +pub mod sync; pub mod type_info; pub mod value; @@ -107,8 +110,8 @@ pub mod mssql; #[cfg(feature = "migrate")] pub mod testing; -pub use sqlx_rt::test_block_on; +pub use error::{Error, Result}; /// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. -use ahash::AHashMap as HashMap; +pub use ahash::AHashMap as HashMap; //type HashMap = std::collections::HashMap; diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index f120e3e1ca..4ae95e82eb 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -39,7 +39,7 @@ impl Migrator { /// ```rust,no_run /// # use sqlx_core::migrate::MigrateError; /// # fn main() -> Result<(), MigrateError> { - /// # sqlx_rt::block_on(async move { + /// # sqlx::__rt::test_block_on(async move { /// # use sqlx_core::migrate::Migrator; /// use std::path::Path; /// @@ -94,7 +94,7 @@ impl Migrator { /// # use sqlx_core::migrate::MigrateError; /// # #[cfg(feature = "sqlite")] /// # fn main() -> Result<(), MigrateError> { - /// # sqlx_rt::block_on(async move { + /// # sqlx::__rt::test_block_on(async move { /// # use sqlx_core::migrate::Migrator; /// let m = Migrator::new(std::path::Path::new("./migrations")).await?; /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?; @@ -173,7 +173,7 @@ impl Migrator { /// # use sqlx_core::migrate::MigrateError; /// # #[cfg(feature = "sqlite")] /// # fn main() -> Result<(), MigrateError> { - /// # sqlx_rt::block_on(async move { + /// # sqlx::__rt::test_block_on(async move { /// # use sqlx_core::migrate::Migrator; /// let m = Migrator::new(std::path::Path::new("./migrations")).await?; /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?; diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs index cd0cdca39d..609f4fdeaa 100644 --- a/sqlx-core/src/migrate/source.rs +++ b/sqlx-core/src/migrate/source.rs @@ -1,8 +1,8 @@ use crate::error::BoxDynError; +use crate::fs; use crate::migrate::{Migration, MigrationType}; use futures_core::future::BoxFuture; -use futures_util::TryStreamExt; -use sqlx_rt::fs; + use std::borrow::Cow; use std::fmt::Debug; use std::path::{Path, PathBuf}; @@ -20,21 +20,16 @@ pub trait MigrationSource<'s>: Debug { impl<'s> MigrationSource<'s> for &'s Path { fn resolve(self) -> BoxFuture<'s, Result, BoxDynError>> { Box::pin(async move { - #[allow(unused_mut)] let mut s = fs::read_dir(self.canonicalize()?).await?; let mut migrations = Vec::new(); - #[cfg(feature = "_rt-tokio")] - let mut s = tokio_stream::wrappers::ReadDirStream::new(s); - - while let Some(entry) = s.try_next().await? { - if !entry.metadata().await?.is_file() { + while let Some(entry) = s.next().await? { + if !entry.metadata.is_file() { // not a file; ignore continue; } - let file_name = entry.file_name(); - let file_name = file_name.to_string_lossy(); + let file_name = entry.file_name.to_string_lossy(); let parts = file_name.splitn(2, '_').collect::>(); @@ -52,7 +47,7 @@ impl<'s> MigrationSource<'s> for &'s Path { .replace('_', " ") .to_owned(); - let sql = fs::read_to_string(&entry.path()).await?; + let sql = fs::read_to_string(&entry.path).await?; migrations.push(Migration::new( version, diff --git a/sqlx-core/src/mssql/connection/mod.rs b/sqlx-core/src/mssql/connection/mod.rs index 8585f7cf99..2b6558c98e 100644 --- a/sqlx-core/src/mssql/connection/mod.rs +++ b/sqlx-core/src/mssql/connection/mod.rs @@ -37,22 +37,10 @@ impl Connection for MssqlConnection { fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { // NOTE: there does not seem to be a clean shutdown packet to send to MSSQL - #[cfg(feature = "_rt-async-std")] - { - use std::future::ready; - use std::net::Shutdown; - - ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed() - } - - #[cfg(feature = "_rt-tokio")] - { - use sqlx_rt::AsyncWriteExt; - - // FIXME: This is equivalent to Shutdown::Write, not Shutdown::Both like above - // https://docs.rs/tokio/1.0.1/tokio/io/trait.AsyncWriteExt.html#method.shutdown - async move { self.stream.shutdown().await.map_err(Into::into) }.boxed() - } + Box::pin(async move { + self.stream.shutdown().await?; + Ok(()) + }) } fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> { @@ -78,6 +66,6 @@ impl Connection for MssqlConnection { #[doc(hidden)] fn should_flush(&self) -> bool { - !self.stream.wbuf.is_empty() + !self.stream.write_buffer().is_empty() } } diff --git a/sqlx-core/src/mssql/connection/stream.rs b/sqlx-core/src/mssql/connection/stream.rs index 1ce061d508..79888429f8 100644 --- a/sqlx-core/src/mssql/connection/stream.rs +++ b/sqlx-core/src/mssql/connection/stream.rs @@ -1,11 +1,11 @@ use std::ops::{Deref, DerefMut}; +use std::sync::Arc; -use bytes::{Bytes, BytesMut}; -use sqlx_rt::TcpStream; +use bytes::{BufMut, Bytes, BytesMut}; use crate::error::Error; use crate::ext::ustr::UStr; -use crate::io::{BufStream, Encode}; +use crate::io::Encode; use crate::mssql::protocol::col_meta_data::ColMetaData; use crate::mssql::protocol::done::{Done, Status as DoneStatus}; use crate::mssql::protocol::env_change::EnvChange; @@ -19,12 +19,11 @@ use crate::mssql::protocol::return_status::ReturnStatus; use crate::mssql::protocol::return_value::ReturnValue; use crate::mssql::protocol::row::Row; use crate::mssql::{MssqlColumn, MssqlConnectOptions, MssqlDatabaseError}; -use crate::net::MaybeTlsStream; +use crate::net::{BufferedSocket, Socket, SocketIntoBox}; use crate::HashMap; -use std::sync::Arc; pub(crate) struct MssqlStream { - inner: BufStream>, + inner: BufferedSocket>, // how many Done (or Error) we are currently waiting for pub(crate) pending_done_count: usize, @@ -45,12 +44,10 @@ pub(crate) struct MssqlStream { impl MssqlStream { pub(super) async fn connect(options: &MssqlConnectOptions) -> Result { - let inner = BufStream::new(MaybeTlsStream::Raw( - TcpStream::connect((&*options.host, options.port)).await?, - )); + let socket = crate::net::connect_tcp(&options.host, options.port, SocketIntoBox).await?; Ok(Self { - inner, + inner: BufferedSocket::new(socket), columns: Default::default(), column_names: Default::default(), response: None, @@ -68,6 +65,7 @@ impl MssqlStream { // write out the packet header, leaving room for setting the packet length later + let starting_buf_len = self.inner.write_buffer().get().len(); let mut len_offset = 0; self.inner.write_with( @@ -78,15 +76,18 @@ impl MssqlStream { server_process_id: 0, packet_id: 1, }, + // updated by `PacketHeader::encode()` &mut len_offset, ); // write out the payload self.inner.write(payload); + let buf = self.inner.write_buffer_mut().get_mut(); + // overwrite the packet length now that we know it - let len = self.inner.wbuf.len(); - self.inner.wbuf[len_offset..(len_offset + 2)].copy_from_slice(&(len as u16).to_be_bytes()); + let len = buf.len() - starting_buf_len; + (&mut buf[len_offset..(len_offset + 2)]).put_u16(len as u16); } // receive the next packet from the database @@ -106,10 +107,13 @@ impl MssqlStream { let mut payload = BytesMut::new(); loop { - self.inner - .read_raw_into(&mut payload, (header.length - 8) as usize) + let chunk = self + .inner + .read_buffered((header.length - 8) as usize) .await?; + payload.unsplit(chunk); + if header.status.contains(Status::END_OF_MESSAGE) { break; } @@ -202,7 +206,7 @@ impl MssqlStream { } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.wbuf.is_empty() { + if !self.write_buffer().is_empty() { self.flush().await?; } @@ -222,7 +226,7 @@ impl MssqlStream { } impl Deref for MssqlStream { - type Target = BufStream>; + type Target = BufferedSocket>; fn deref(&self) -> &Self::Target { &self.inner diff --git a/sqlx-core/src/mysql/connection/auth.rs b/sqlx-core/src/mysql/connection/auth.rs index 237fd55288..bb04684dc3 100644 --- a/sqlx-core/src/mysql/connection/auth.rs +++ b/sqlx-core/src/mysql/connection/auth.rs @@ -131,7 +131,7 @@ async fn encrypt_rsa<'s>( ) -> Result, Error> { // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ - if stream.is_tls() { + if stream.is_tls { // If in a TLS stream, send the password directly in clear text return Ok(to_asciz(password)); } diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs index 5352b1a10c..1e65ca710b 100644 --- a/sqlx-core/src/mysql/connection/establish.rs +++ b/sqlx-core/src/mysql/connection/establish.rs @@ -1,26 +1,77 @@ use bytes::buf::Buf; use bytes::Bytes; +use futures_core::future::BoxFuture; use crate::common::StatementCache; use crate::error::Error; +use crate::mysql::collation::{CharSet, Collation}; use crate::mysql::connection::{tls, MySqlStream, MAX_PACKET_SIZE}; use crate::mysql::protocol::connect::{ AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse, }; use crate::mysql::protocol::Capabilities; -use crate::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; +use crate::mysql::{MySqlConnectOptions, MySqlConnection}; +use crate::net::{Socket, WithSocket}; impl MySqlConnection { pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result { - let mut stream: MySqlStream = MySqlStream::connect(options).await?; + let do_handshake = DoHandshake::new(options)?; - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html + let handshake = match &options.socket { + Some(path) => crate::net::connect_uds(path, do_handshake).await?, + None => crate::net::connect_tcp(&options.host, options.port, do_handshake).await?, + }; + + let stream = handshake.await?; + + Ok(Self { + stream, + transaction_depth: 0, + cache_statement: StatementCache::new(options.statement_cache_capacity), + log_settings: options.log_settings.clone(), + }) + } +} + +struct DoHandshake<'a> { + options: &'a MySqlConnectOptions, + charset: CharSet, + collation: Collation, +} + +impl<'a> DoHandshake<'a> { + fn new(options: &'a MySqlConnectOptions) -> Result { + let charset: CharSet = options.charset.parse()?; + let collation: Collation = options + .collation + .as_deref() + .map(|collation| collation.parse()) + .transpose()? + .unwrap_or_else(|| charset.default_collation()); + + Ok(Self { + options, + charset, + collation, + }) + } + + async fn do_handshake(self, socket: S) -> Result { + let DoHandshake { + options, + charset, + collation, + } = self; + + let mut stream = MySqlStream::with_socket(charset, collation, options, socket); + + // https://dev.mysql.com/doc/internals/en/connection-phase.html // https://mariadb.com/kb/en/connection/ let handshake: Handshake = stream.recv_packet().await?.decode()?; let mut plugin = handshake.auth_plugin; - let mut nonce = handshake.auth_plugin_data; + let nonce = handshake.auth_plugin_data; // FIXME: server version parse is a bit ugly // expecting MAJOR.MINOR.PATCH @@ -54,39 +105,7 @@ impl MySqlConnection { stream.capabilities &= handshake.server_capabilities; stream.capabilities |= Capabilities::PROTOCOL_41; - if matches!(options.ssl_mode, MySqlSslMode::Disabled) { - // remove the SSL capability if SSL has been explicitly disabled - stream.capabilities.remove(Capabilities::SSL); - } - - // Upgrade to TLS if we were asked to and the server supports it - - #[cfg(feature = "_tls-rustls")] - { - // To aid in debugging: https://github.com/rustls/rustls/issues/893 - - let local_addr = stream.local_addr(); - - match tls::maybe_upgrade(&mut stream, options).await { - Ok(()) => (), - #[cfg(feature = "_tls-rustls")] - Err(Error::Io(ioe)) => { - if let Some(&rustls::Error::CorruptMessage) = - ioe.get_ref().and_then(|e| e.downcast_ref()) - { - log::trace!("got corrupt message on socket {:?}", local_addr); - } - - return Err(Error::Io(ioe)); - } - Err(e) => return Err(e), - } - } - - #[cfg(not(feature = "_tls-rustls"))] - { - tls::maybe_upgrade(&mut stream, options).await? - } + let mut stream = tls::maybe_upgrade(stream, self.options).await?; let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) { Some(plugin.scramble(&mut stream, password, &nonce).await?) @@ -118,7 +137,7 @@ impl MySqlConnection { let switch: AuthSwitchRequest = packet.decode()?; plugin = Some(switch.plugin); - nonce = switch.data.chain(Bytes::new()); + let nonce = switch.data.chain(Bytes::new()); let response = switch .plugin @@ -140,7 +159,7 @@ impl MySqlConnection { break; } - // plugin signaled to continue authentication + // plugin signaled to continue authentication } else { return Err(err_protocol!( "unexpected packet 0x{:02x} during authentication", @@ -151,11 +170,14 @@ impl MySqlConnection { } } - Ok(Self { - stream, - transaction_depth: 0, - cache_statement: StatementCache::new(options.statement_cache_capacity), - log_settings: options.log_settings.clone(), - }) + Ok(stream) + } +} + +impl<'a> WithSocket for DoHandshake<'a> { + type Output = BoxFuture<'a, Result>; + + fn with_socket(self, socket: S) -> Self::Output { + Box::pin(self.do_handshake(socket)) } } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index 1f87eaa918..f35cac8aae 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -98,7 +98,7 @@ impl Connection for MySqlConnection { #[doc(hidden)] fn should_flush(&self) -> bool { - !self.stream.wbuf.is_empty() + !self.stream.write_buffer().is_empty() } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index dd9a1235b8..5b058b8e34 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -4,22 +4,24 @@ use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufStream, Decode, Encode}; +use crate::io::{Decode, Encode}; use crate::mysql::collation::{CharSet, Collation}; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::mysql::protocol::{Capabilities, Packet}; use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError}; -use crate::net::{MaybeTlsStream, Socket}; +use crate::net::{BufferedSocket, Socket}; -pub struct MySqlStream { - stream: BufStream>, +pub struct MySqlStream> { + // Wrapping the socket in `Box` allows us to unsize in-place. + pub(crate) socket: BufferedSocket, pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, pub(crate) waiting: VecDeque, pub(crate) charset: CharSet, pub(crate) collation: Collation, + pub(crate) is_tls: bool, } #[derive(Debug, PartialEq, Eq)] @@ -31,21 +33,13 @@ pub(crate) enum Waiting { Row, } -impl MySqlStream { - pub(super) async fn connect(options: &MySqlConnectOptions) -> Result { - let charset: CharSet = options.charset.parse()?; - let collation: Collation = options - .collation - .as_deref() - .map(|collation| collation.parse()) - .transpose()? - .unwrap_or_else(|| charset.default_collation()); - - let socket = match options.socket { - Some(ref path) => Socket::connect_uds(path).await?, - None => Socket::connect_tcp(&options.host, options.port).await?, - }; - +impl MySqlStream { + pub(crate) fn with_socket( + charset: CharSet, + collation: Collation, + options: &MySqlConnectOptions, + socket: S, + ) -> Self { let mut capabilities = Capabilities::PROTOCOL_41 | Capabilities::IGNORE_SPACE | Capabilities::DEPRECATE_EOF @@ -63,20 +57,21 @@ impl MySqlStream { capabilities |= Capabilities::CONNECT_WITH_DB; } - Ok(Self { + Self { waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, collation, charset, - stream: BufStream::new(MaybeTlsStream::Raw(socket)), - }) + socket: BufferedSocket::new(socket), + is_tls: false, + } } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.stream.wbuf.is_empty() { - self.stream.flush().await?; + if !self.socket.write_buffer().is_empty() { + self.socket.flush().await?; } while !self.waiting.is_empty() { @@ -119,14 +114,15 @@ impl MySqlStream { { self.sequence_id = 0; self.write_packet(payload); - self.flush().await + self.flush().await?; + Ok(()) } pub(crate) fn write_packet<'en, T>(&mut self, payload: T) where T: Encode<'en, Capabilities>, { - self.stream + self.socket .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)); } @@ -136,14 +132,14 @@ impl MySqlStream { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header: Bytes = self.stream.read(4).await?; + let mut header: Bytes = self.socket.read(4).await?; let packet_size = header.get_uint_le(3) as usize; let sequence_id = header.get_u8(); self.sequence_id = sequence_id.wrapping_add(1); - let payload: Bytes = self.stream.read(packet_size).await?; + let payload: Bytes = self.socket.read(packet_size).await?; // TODO: packet compression // TODO: packet joining @@ -195,18 +191,31 @@ impl MySqlStream { Ok(()) } + + pub fn boxed_socket(self) -> MySqlStream { + MySqlStream { + socket: self.socket.boxed(), + server_version: self.server_version, + capabilities: self.capabilities, + sequence_id: self.sequence_id, + waiting: self.waiting, + charset: self.charset, + collation: self.collation, + is_tls: self.is_tls, + } + } } -impl Deref for MySqlStream { - type Target = BufStream>; +impl Deref for MySqlStream { + type Target = BufferedSocket; fn deref(&self) -> &Self::Target { - &self.stream + &self.socket } } -impl DerefMut for MySqlStream { +impl DerefMut for MySqlStream { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.stream + &mut self.socket } } diff --git a/sqlx-core/src/mysql/connection/tls.rs b/sqlx-core/src/mysql/connection/tls.rs index 468b638fa8..d1f13792b7 100644 --- a/sqlx-core/src/mysql/connection/tls.rs +++ b/sqlx-core/src/mysql/connection/tls.rs @@ -1,39 +1,72 @@ use crate::error::Error; -use crate::mysql::connection::MySqlStream; +use crate::mysql::collation::{CharSet, Collation}; +use crate::mysql::connection::{MySqlStream, Waiting}; use crate::mysql::protocol::connect::SslRequest; use crate::mysql::protocol::Capabilities; use crate::mysql::{MySqlConnectOptions, MySqlSslMode}; +use crate::net::tls::TlsConfig; +use crate::net::{tls, BufferedSocket, Socket, WithSocket}; +use std::collections::VecDeque; -pub(super) async fn maybe_upgrade( - stream: &mut MySqlStream, +struct MapStream { + server_version: (u16, u16, u16), + capabilities: Capabilities, + sequence_id: u8, + waiting: VecDeque, + charset: CharSet, + collation: Collation, +} + +pub(super) async fn maybe_upgrade( + mut stream: MySqlStream, options: &MySqlConnectOptions, -) -> Result<(), Error> { +) -> Result { + let server_supports_tls = stream.capabilities.contains(Capabilities::SSL); + + if matches!(options.ssl_mode, MySqlSslMode::Disabled) || !tls::available() { + // remove the SSL capability if SSL has been explicitly disabled + stream.capabilities.remove(Capabilities::SSL); + } + // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS match options.ssl_mode { - MySqlSslMode::Disabled => {} + MySqlSslMode::Disabled => return Ok(stream.boxed_socket()), MySqlSslMode::Preferred => { - // try upgrade, but its okay if we fail - upgrade(stream, options).await?; + if !tls::available() { + // Client doesn't support TLS + log::debug!("not performing TLS upgrade: TLS support not compiled in"); + return Ok(stream.boxed_socket()); + } + + if !server_supports_tls { + // Server doesn't support TLS + log::debug!("not performing TLS upgrade: unsupported by server"); + return Ok(stream.boxed_socket()); + } } MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => { - if !upgrade(stream, options).await? { + tls::error_if_unavailable()?; + + if !server_supports_tls { // upgrade failed, die return Err(Error::Tls("server does not support TLS".into())); } } } - Ok(()) -} - -async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Result { - if !stream.capabilities.contains(Capabilities::SSL) { - // server does not support TLS - return Ok(false); - } + let tls_config = TlsConfig { + accept_invalid_certs: !matches!( + options.ssl_mode, + MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity + ), + accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity), + hostname: &options.host, + root_cert_path: options.ssl_ca.as_ref(), + }; + // Request TLS upgrade stream.write_packet(SslRequest { max_packet_size: super::MAX_PACKET_SIZE, collation: stream.collation as u8, @@ -41,20 +74,34 @@ async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Res stream.flush().await?; - let accept_invalid_certs = !matches!( - options.ssl_mode, - MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity - ); - let accept_invalid_host_names = !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity); - - stream - .upgrade( - &options.host, - accept_invalid_certs, - accept_invalid_host_names, - options.ssl_ca.as_ref(), - ) - .await?; - - Ok(true) + tls::handshake( + stream.socket.into_inner(), + tls_config, + MapStream { + server_version: stream.server_version, + capabilities: stream.capabilities, + sequence_id: stream.sequence_id, + waiting: stream.waiting, + charset: stream.charset, + collation: stream.collation, + }, + ) + .await +} + +impl WithSocket for MapStream { + type Output = MySqlStream; + + fn with_socket(self, socket: S) -> Self::Output { + MySqlStream { + socket: BufferedSocket::new(Box::new(socket)), + server_version: self.server_version, + capabilities: self.capabilities, + sequence_id: self.sequence_id, + waiting: self.waiting, + charset: self.charset, + collation: self.collation, + is_tls: true, + } + } } diff --git a/sqlx-core/src/mysql/options/mod.rs b/sqlx-core/src/mysql/options/mod.rs index 5d152c3869..d0959579d1 100644 --- a/sqlx-core/src/mysql/options/mod.rs +++ b/sqlx-core/src/mysql/options/mod.rs @@ -4,7 +4,7 @@ mod connect; mod parse; mod ssl_mode; -use crate::{connection::LogSettings, net::CertificateInput}; +use crate::{connection::LogSettings, net::tls::CertificateInput}; pub use ssl_mode::MySqlSslMode; /// Options and flags which can be used to configure a MySQL connection. @@ -35,8 +35,8 @@ pub use ssl_mode::MySqlSslMode; /// # use sqlx_core::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; /// # /// # fn main() { -/// # #[cfg(feature = "_rt-async-std")] -/// # sqlx_rt::async_std::task::block_on::<_, Result<(), Error>>(async move { +/// # #[cfg(feature = "_rt")] +/// # sqlx::__rt::test_block_on(async move { /// // URL connection string /// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?; /// @@ -47,7 +47,7 @@ pub use ssl_mode::MySqlSslMode; /// .password("password") /// .database("db") /// .connect().await?; -/// # Ok(()) +/// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// # } /// ``` diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index 429c5f6c44..3c75f32c92 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -1,17 +1,4 @@ mod socket; -mod tls; +pub mod tls; -pub use socket::Socket; -pub use tls::{CertificateInput, MaybeTlsStream}; - -#[cfg(feature = "_rt-async-std")] -type PollReadBuf<'a> = [u8]; - -#[cfg(feature = "_rt-tokio")] -type PollReadBuf<'a> = sqlx_rt::ReadBuf<'a>; - -#[cfg(feature = "_rt-async-std")] -type PollReadOut = usize; - -#[cfg(feature = "_rt-tokio")] -type PollReadOut = (); +pub use socket::{connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket}; diff --git a/sqlx-core/src/net/socket.rs b/sqlx-core/src/net/socket.rs deleted file mode 100644 index 622a1a22ce..0000000000 --- a/sqlx-core/src/net/socket.rs +++ /dev/null @@ -1,134 +0,0 @@ -#![allow(dead_code)] - -use std::io; -use std::net::SocketAddr; -use std::path::Path; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use sqlx_rt::{AsyncRead, AsyncWrite, TcpStream}; - -#[derive(Debug)] -pub enum Socket { - Tcp(TcpStream), - - #[cfg(unix)] - Unix(sqlx_rt::UnixStream), -} - -impl Socket { - pub async fn connect_tcp(host: &str, port: u16) -> io::Result { - // Trim square brackets from host if it's an IPv6 address as the `url` crate doesn't do that. - TcpStream::connect((host.trim_matches(|c| c == '[' || c == ']'), port)) - .await - .map(Socket::Tcp) - } - - #[cfg(unix)] - pub async fn connect_uds(path: impl AsRef) -> io::Result { - sqlx_rt::UnixStream::connect(path.as_ref()) - .await - .map(Socket::Unix) - } - - pub fn local_addr(&self) -> Option { - match self { - Self::Tcp(tcp) => tcp.local_addr().ok(), - #[cfg(unix)] - Self::Unix(_) => None, - } - } - - #[cfg(not(unix))] - pub async fn connect_uds(_: impl AsRef) -> io::Result { - Err(io::Error::new( - io::ErrorKind::Other, - "Unix domain sockets are not supported outside Unix platforms.", - )) - } - - pub async fn shutdown(&mut self) -> io::Result<()> { - #[cfg(feature = "_rt-async-std")] - { - use std::net::Shutdown; - - match self { - Socket::Tcp(s) => s.shutdown(Shutdown::Both), - - #[cfg(unix)] - Socket::Unix(s) => s.shutdown(Shutdown::Both), - } - } - - #[cfg(feature = "_rt-tokio")] - { - use sqlx_rt::AsyncWriteExt; - - match self { - Socket::Tcp(s) => s.shutdown().await, - - #[cfg(unix)] - Socket::Unix(s) => s.shutdown().await, - } - } - } -} - -impl AsyncRead for Socket { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut super::PollReadBuf<'_>, - ) -> Poll> { - match &mut *self { - Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf), - - #[cfg(unix)] - Socket::Unix(s) => Pin::new(s).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for Socket { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - Socket::Tcp(s) => Pin::new(s).poll_write(cx, buf), - - #[cfg(unix)] - Socket::Unix(s) => Pin::new(s).poll_write(cx, buf), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Socket::Tcp(s) => Pin::new(s).poll_flush(cx), - - #[cfg(unix)] - Socket::Unix(s) => Pin::new(s).poll_flush(cx), - } - } - - #[cfg(feature = "_rt-tokio")] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Socket::Tcp(s) => Pin::new(s).poll_shutdown(cx), - - #[cfg(unix)] - Socket::Unix(s) => Pin::new(s).poll_shutdown(cx), - } - } - - #[cfg(feature = "_rt-async-std")] - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Socket::Tcp(s) => Pin::new(s).poll_close(cx), - - #[cfg(unix)] - Socket::Unix(s) => Pin::new(s).poll_close(cx), - } - } -} diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs new file mode 100644 index 0000000000..dc05c87863 --- /dev/null +++ b/sqlx-core/src/net/socket/buffered.rs @@ -0,0 +1,234 @@ +use crate::net::Socket; +use bytes::BytesMut; +use std::io; + +use crate::error::Error; + +use crate::io::{Decode, Encode}; + +// Tokio, async-std, and std all use this as the default capacity for their buffered I/O. +const DEFAULT_BUF_SIZE: usize = 8192; + +pub struct BufferedSocket { + socket: S, + write_buf: WriteBuffer, + read_buf: ReadBuffer, +} + +pub struct WriteBuffer { + buf: Vec, + bytes_written: usize, + bytes_flushed: usize, +} + +pub struct ReadBuffer { + read: BytesMut, + available: BytesMut, +} + +impl BufferedSocket { + pub fn new(socket: S) -> Self + where + S: Sized, + { + BufferedSocket { + socket, + write_buf: WriteBuffer { + buf: Vec::with_capacity(DEFAULT_BUF_SIZE), + bytes_written: 0, + bytes_flushed: 0, + }, + read_buf: ReadBuffer { + read: BytesMut::new(), + available: BytesMut::with_capacity(DEFAULT_BUF_SIZE), + }, + } + } + + pub async fn read_buffered(&mut self, len: usize) -> io::Result { + while self.read_buf.read.len() < len { + self.read_buf.reserve(len); + + let read = self.socket.read(&mut self.read_buf.available).await?; + + if read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!( + "expected to read {} bytes, got {} bytes at EOF", + len, + self.read_buf.read.len() + ), + )); + } + + self.read_buf.advance(read); + } + + Ok(self.read_buf.drain(len)) + } + + pub fn write_buffer(&self) -> &WriteBuffer { + &self.write_buf + } + + pub fn write_buffer_mut(&mut self) -> &mut WriteBuffer { + &mut self.write_buf + } + + pub async fn read<'de, T>(&mut self, byte_len: usize) -> Result + where + T: Decode<'de, ()>, + { + self.read_with(byte_len, ()).await + } + + pub async fn read_with<'de, T, C>(&mut self, byte_len: usize, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.read_buffered(byte_len).await?.freeze(), context) + } + + pub fn write<'en, T>(&mut self, value: T) + where + T: Encode<'en, ()>, + { + self.write_with(value, ()) + } + + pub fn write_with<'en, T, C>(&mut self, value: T, context: C) + where + T: Encode<'en, C>, + { + value.encode_with(self.write_buf.buf_mut(), context); + self.write_buf.bytes_written = self.write_buf.buf.len(); + self.write_buf.sanity_check(); + } + + pub async fn flush(&mut self) -> io::Result<()> { + while !self.write_buf.is_empty() { + let written = self.socket.write(self.write_buf.get()).await?; + self.write_buf.consume(written); + self.write_buf.sanity_check(); + } + + self.socket.flush().await?; + + Ok(()) + } + + pub async fn shutdown(&mut self) -> io::Result<()> { + self.flush().await?; + self.socket.shutdown().await + } + + pub fn into_inner(self) -> S { + self.socket + } + + pub fn boxed(self) -> BufferedSocket> { + BufferedSocket { + socket: Box::new(self.socket), + write_buf: self.write_buf, + read_buf: self.read_buf, + } + } +} + +impl WriteBuffer { + fn sanity_check(&self) { + assert_ne!(self.buf.capacity(), 0); + assert!(self.bytes_written <= self.buf.len()); + assert!(self.bytes_flushed <= self.bytes_written); + } + + pub fn buf_mut(&mut self) -> &mut Vec { + self.buf.truncate(self.bytes_written); + self.sanity_check(); + &mut self.buf + } + + pub fn init_remaining_mut(&mut self) -> &mut [u8] { + self.buf.resize(self.buf.capacity(), 0); + self.sanity_check(); + &mut self.buf[self.bytes_written..] + } + + pub fn put_slice(&mut self, slice: &[u8]) { + // If we already have an initialized area that can fit the slice, + // don't change `self.buf.len()` + if let Some(dest) = self.buf[self.bytes_written..].get_mut(..slice.len()) { + dest.copy_from_slice(slice); + } else { + self.buf.truncate(self.bytes_written); + self.buf.extend_from_slice(slice); + } + + self.sanity_check(); + } + + pub fn advance(&mut self, amt: usize) { + let new_bytes_written = self + .bytes_written + .checked_add(amt) + .expect("self.bytes_written + amt overflowed"); + + assert!(new_bytes_written <= self.buf.len()); + + self.bytes_written = new_bytes_written; + + self.sanity_check(); + } + + pub fn is_empty(&self) -> bool { + self.bytes_flushed >= self.bytes_written + } + + pub fn is_full(&self) -> bool { + self.bytes_written == self.buf.len() + } + + pub fn get(&self) -> &[u8] { + &self.buf[self.bytes_flushed..self.bytes_written] + } + + pub fn get_mut(&mut self) -> &mut [u8] { + &mut self.buf[self.bytes_flushed..self.bytes_written] + } + + fn consume(&mut self, amt: usize) { + let new_bytes_flushed = self + .bytes_flushed + .checked_add(amt) + .expect("self.bytes_flushed + amt overflowed"); + + assert!(new_bytes_flushed <= self.bytes_written); + + self.bytes_flushed = new_bytes_flushed; + + if self.bytes_flushed == self.bytes_written { + // Reset cursors to zero if we've consumed the whole buffer + self.bytes_flushed = 0; + self.bytes_written = 0; + } + + self.sanity_check(); + } +} + +impl ReadBuffer { + fn reserve(&mut self, amt: usize) { + if let Some(additional) = amt.checked_sub(self.available.capacity()) { + self.available.reserve(additional); + } + } + + fn advance(&mut self, amt: usize) { + self.read.unsplit(self.available.split_to(amt)); + } + + fn drain(&mut self, amt: usize) -> BytesMut { + self.read.split_to(amt) + } +} diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs new file mode 100644 index 0000000000..cd7f24d780 --- /dev/null +++ b/sqlx-core/src/net/socket/mod.rs @@ -0,0 +1,259 @@ +use std::future::Future; +use std::io; +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::BufMut; +use futures_core::ready; + +pub use buffered::{BufferedSocket, WriteBuffer}; + +use crate::io::ReadBuf; + +mod buffered; + +pub trait Socket: Send + Sync + Unpin + 'static { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result; + + fn try_write(&mut self, buf: &[u8]) -> io::Result; + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + + fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll> { + // `flush()` is a no-op for TCP/UDS + Poll::Ready(Ok(())) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll>; + + fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B> + where + Self: Sized, + { + Read { socket: self, buf } + } + + fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self> + where + Self: Sized, + { + Write { socket: self, buf } + } + + fn flush(&mut self) -> Flush<'_, Self> + where + Self: Sized, + { + Flush { socket: self } + } + + fn shutdown(&mut self) -> Shutdown<'_, Self> + where + Self: Sized, + { + Shutdown { socket: self } + } +} + +pub struct Read<'a, S: ?Sized, B> { + socket: &'a mut S, + buf: &'a mut B, +} + +impl<'a, S: ?Sized, B> Future for Read<'a, S, B> +where + S: Socket, + B: ReadBuf, +{ + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + + while this.buf.has_remaining_mut() { + match this.socket.try_read(&mut *this.buf) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + ready!(this.socket.poll_read_ready(cx))?; + } + ready => return Poll::Ready(ready), + } + } + + Poll::Ready(Ok(0)) + } +} + +pub struct Write<'a, S: ?Sized> { + socket: &'a mut S, + buf: &'a [u8], +} + +impl<'a, S: ?Sized> Future for Write<'a, S> +where + S: Socket, +{ + type Output = io::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + + while !this.buf.is_empty() { + match this.socket.try_write(&mut this.buf) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + ready!(this.socket.poll_write_ready(cx))?; + } + ready => return Poll::Ready(ready), + } + } + + Poll::Ready(Ok(0)) + } +} + +pub struct Flush<'a, S: ?Sized> { + socket: &'a mut S, +} + +impl<'a, S: Socket + ?Sized> Future for Flush<'a, S> { + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.socket.poll_flush(cx) + } +} + +pub struct Shutdown<'a, S: ?Sized> { + socket: &'a mut S, +} + +impl<'a, S: ?Sized> Future for Shutdown<'a, S> +where + S: Socket, +{ + type Output = io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.socket.poll_shutdown(cx) + } +} + +pub trait WithSocket { + type Output; + + fn with_socket(self, socket: S) -> Self::Output; +} + +pub struct SocketIntoBox; + +impl WithSocket for SocketIntoBox { + type Output = Box; + + fn with_socket(self, socket: S) -> Self::Output { + Box::new(socket) + } +} + +impl Socket for Box { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + (**self).try_read(buf) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + (**self).try_write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (**self).poll_read_ready(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (**self).poll_write_ready(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + (**self).poll_shutdown(cx) + } +} + +pub async fn connect_tcp( + host: &str, + port: u16, + with_socket: Ws, +) -> crate::Result { + // IPv6 addresses in URLs will be wrapped in brackets and the `url` crate doesn't trim those. + let host = host.trim_matches(&['[', ']'][..]); + + #[cfg(feature = "_rt-tokio")] + if crate::rt::rt_tokio::available() { + use tokio::net::TcpStream; + + let stream = TcpStream::connect((host, port)).await?; + + return Ok(with_socket.with_socket(stream)); + } + + #[cfg(feature = "_rt-async-std")] + { + use async_io::Async; + use async_std::net::ToSocketAddrs; + use std::net::TcpStream; + + let socket_addr = (host, port) + .to_socket_addrs() + .await? + .next() + .expect("BUG: to_socket_addrs() should have returned at least one result"); + + let stream = Async::::connect(socket_addr).await?; + + return Ok(with_socket.with_socket(stream)); + } + + #[cfg(not(feature = "_rt-async-std"))] + { + crate::rt::missing_rt((host, port, with_socket)) + } +} + +/// Connect a Unix Domain Socket at the given path. +/// +/// Returns an error if Unix Domain Sockets are not supported on this platform. +pub async fn connect_uds, Ws: WithSocket>( + path: P, + with_socket: Ws, +) -> crate::Result { + if cfg!(not(unix)) { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "Unix domain sockets are not supported on this platform", + ) + .into()); + } + + #[cfg(all(unix, feature = "_rt-tokio"))] + if crate::rt::rt_tokio::available() { + use tokio::net::UnixStream; + + let stream = UnixStream::connect(path).await?; + + return Ok(with_socket.with_socket(stream)); + } + + #[cfg(all(unix, feature = "_rt-async-std"))] + { + use async_io::Async; + use std::os::unix::net::UnixStream; + + let stream = Async::::connect(path).await?; + + return Ok(with_socket.with_socket(stream)); + } + + #[cfg(not(feature = "_rt-async-std"))] + { + crate::rt::missing_rt((path, with_socket)) + } +} diff --git a/sqlx-core/src/net/tls/mod.rs b/sqlx-core/src/net/tls/mod.rs index 85e5dda7c1..3fae0ecaed 100644 --- a/sqlx-core/src/net/tls/mod.rs +++ b/sqlx-core/src/net/tls/mod.rs @@ -1,15 +1,18 @@ #![allow(dead_code)] -use std::io; -use std::ops::{Deref, DerefMut}; use std::path::PathBuf; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream}; use crate::error::Error; -use std::mem::replace; +use crate::net::socket::WithSocket; +use crate::net::Socket; + +#[cfg(feature = "_tls-rustls")] +mod tls_rustls; + +#[cfg(feature = "_tls-native-tls")] +mod tls_native_tls; + +mod util; /// X.509 Certificate input, either a file path or a PEM encoded inline certificate(s). #[derive(Clone, Debug)] @@ -36,7 +39,7 @@ impl From for CertificateInput { impl CertificateInput { async fn data(&self) -> Result, std::io::Error> { - use sqlx_rt::fs; + use crate::fs; match self { CertificateInput::Inline(v) => Ok(v.clone()), CertificateInput::File(path) => fs::read(path).await, @@ -53,210 +56,46 @@ impl std::fmt::Display for CertificateInput { } } -#[cfg(feature = "_tls-rustls")] -mod rustls; - -pub enum MaybeTlsStream -where - S: AsyncRead + AsyncWrite + Unpin, -{ - Raw(S), - Tls(TlsStream), - Upgrading, +pub struct TlsConfig<'a> { + pub accept_invalid_certs: bool, + pub accept_invalid_hostnames: bool, + pub hostname: &'a str, + pub root_cert_path: Option<&'a CertificateInput>, } -impl MaybeTlsStream +pub async fn handshake( + socket: S, + config: TlsConfig<'_>, + with_socket: Ws, +) -> crate::Result where - S: AsyncRead + AsyncWrite + Unpin, + S: Socket, + Ws: WithSocket, { - #[inline] - pub fn is_tls(&self) -> bool { - matches!(self, Self::Tls(_)) - } - - pub async fn upgrade( - &mut self, - host: &str, - accept_invalid_certs: bool, - accept_invalid_hostnames: bool, - root_cert_path: Option<&CertificateInput>, - ) -> Result<(), Error> { - let connector = configure_tls_connector( - accept_invalid_certs, - accept_invalid_hostnames, - root_cert_path, - ) - .await?; - - let stream = match replace(self, MaybeTlsStream::Upgrading) { - MaybeTlsStream::Raw(stream) => stream, - - MaybeTlsStream::Tls(_) => { - // ignore upgrade, we are already a TLS connection - return Ok(()); - } - - MaybeTlsStream::Upgrading => { - // we previously failed to upgrade and now hold no connection - // this should only happen from an internal misuse of this method - return Err(Error::Io(io::ErrorKind::ConnectionAborted.into())); - } - }; - - #[cfg(feature = "_tls-rustls")] - let host = ::rustls::ServerName::try_from(host).map_err(|err| Error::Tls(err.into()))?; - - *self = MaybeTlsStream::Tls(connector.connect(host, stream).await?); - - Ok(()) - } -} - -#[cfg(feature = "_tls-native-tls")] -async fn configure_tls_connector( - accept_invalid_certs: bool, - accept_invalid_hostnames: bool, - root_cert_path: Option<&CertificateInput>, -) -> Result { - use sqlx_rt::native_tls::{Certificate, TlsConnector}; - - let mut builder = TlsConnector::builder(); - builder - .danger_accept_invalid_certs(accept_invalid_certs) - .danger_accept_invalid_hostnames(accept_invalid_hostnames); - - if !accept_invalid_certs { - if let Some(ca) = root_cert_path { - let data = ca.data().await?; - let cert = Certificate::from_pem(&data)?; - - builder.add_root_certificate(cert); - } - } + #[cfg(feature = "_tls-native-tls")] + return Ok(with_socket.with_socket(tls_native_tls::handshake(socket, config).await?)); - #[cfg(not(feature = "_rt-async-std"))] - let connector = builder.build()?.into(); + #[cfg(feature = "_tls-rustls")] + return Ok(with_socket.with_socket(tls_rustls::handshake(socket, config).await?)); - #[cfg(feature = "_rt-async-std")] - let connector = builder.into(); - - Ok(connector) -} - -#[cfg(feature = "_tls-rustls")] -use self::rustls::configure_tls_connector; - -impl AsyncRead for MaybeTlsStream -where - S: Unpin + AsyncWrite + AsyncRead, -{ - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut super::PollReadBuf<'_>, - ) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + #[cfg(not(any(feature = "_tls-native-tls", feature = "_tls-rustls")))] + { + drop((socket, config, with_socket)); + panic!("one of the `runtime-*-native-tls` or `runtime-*-rustls` features must be enabled") } } -impl AsyncWrite for MaybeTlsStream -where - S: Unpin + AsyncWrite + AsyncRead, -{ - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } - } - - #[cfg(feature = "_rt-tokio")] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } - } - - #[cfg(feature = "_rt-async-std")] - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } - } +pub fn available() -> bool { + cfg!(any(feature = "_tls-native-tls", feature = "_tls-rustls")) } -impl Deref for MaybeTlsStream -where - S: Unpin + AsyncWrite + AsyncRead, -{ - type Target = S; - - fn deref(&self) -> &Self::Target { - match self { - MaybeTlsStream::Raw(s) => s, - - #[cfg(feature = "_tls-rustls")] - MaybeTlsStream::Tls(s) => s.get_ref().0, - - #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))] - MaybeTlsStream::Tls(s) => s.get_ref(), - - #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))] - MaybeTlsStream::Tls(s) => s.get_ref().get_ref().get_ref(), - - MaybeTlsStream::Upgrading => { - panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted)) - } - } +pub fn error_if_unavailable() -> crate::Result<()> { + if !available() { + return Err(Error::tls( + "TLS upgrade required by connect options \ + but SQLx was built without TLS support enabled", + )); } -} -impl DerefMut for MaybeTlsStream -where - S: Unpin + AsyncWrite + AsyncRead, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - MaybeTlsStream::Raw(s) => s, - - #[cfg(feature = "_tls-rustls")] - MaybeTlsStream::Tls(s) => s.get_mut().0, - - #[cfg(all(feature = "_rt-async-std", feature = "_tls-native-tls"))] - MaybeTlsStream::Tls(s) => s.get_mut(), - - #[cfg(all(not(feature = "_rt-async-std"), feature = "_tls-native-tls"))] - MaybeTlsStream::Tls(s) => s.get_mut().get_mut().get_mut(), - - MaybeTlsStream::Upgrading => { - panic!("{}", io::Error::from(io::ErrorKind::ConnectionAborted)) - } - } - } + Ok(()) } diff --git a/sqlx-core/src/net/tls/rustls.rs b/sqlx-core/src/net/tls/rustls.rs deleted file mode 100644 index 2ad958b0d2..0000000000 --- a/sqlx-core/src/net/tls/rustls.rs +++ /dev/null @@ -1,108 +0,0 @@ -use crate::net::CertificateInput; -use rustls::{ - client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier}, - ClientConfig, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName, -}; -use std::io::Cursor; -use std::sync::Arc; -use std::time::SystemTime; - -use crate::error::Error; - -pub async fn configure_tls_connector( - accept_invalid_certs: bool, - accept_invalid_hostnames: bool, - root_cert_path: Option<&CertificateInput>, -) -> Result { - let config = ClientConfig::builder().with_safe_defaults(); - - let config = if accept_invalid_certs { - config - .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier)) - .with_no_client_auth() - } else { - let mut cert_store = RootCertStore::empty(); - cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - if let Some(ca) = root_cert_path { - let data = ca.data().await?; - let mut cursor = Cursor::new(data); - - for cert in rustls_pemfile::certs(&mut cursor) - .map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))? - { - cert_store - .add(&rustls::Certificate(cert)) - .map_err(|err| Error::Tls(err.into()))?; - } - } - - if accept_invalid_hostnames { - let verifier = WebPkiVerifier::new(cert_store, None); - - config - .with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier })) - .with_no_client_auth() - } else { - config - .with_root_certificates(cert_store) - .with_no_client_auth() - } - }; - - Ok(Arc::new(config).into()) -} - -struct DummyTlsVerifier; - -impl ServerCertVerifier for DummyTlsVerifier { - fn verify_server_cert( - &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: SystemTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } -} - -pub struct NoHostnameTlsVerifier { - verifier: WebPkiVerifier, -} - -impl ServerCertVerifier for NoHostnameTlsVerifier { - fn verify_server_cert( - &self, - end_entity: &rustls::Certificate, - intermediates: &[rustls::Certificate], - server_name: &ServerName, - scts: &mut dyn Iterator, - ocsp_response: &[u8], - now: SystemTime, - ) -> Result { - match self.verifier.verify_server_cert( - end_entity, - intermediates, - server_name, - scts, - ocsp_response, - now, - ) { - Err(TlsError::InvalidCertificateData(reason)) - if reason.contains("CertNotValidForName") => - { - Ok(ServerCertVerified::assertion()) - } - res => res, - } - } -} diff --git a/sqlx-core/src/net/tls/tls_native_tls.rs b/sqlx-core/src/net/tls/tls_native_tls.rs new file mode 100644 index 0000000000..5405bac3c2 --- /dev/null +++ b/sqlx-core/src/net/tls/tls_native_tls.rs @@ -0,0 +1,82 @@ +use std::io::{self, Read, Write}; + +use crate::io::ReadBuf; +use crate::net::tls::util::StdSocket; +use crate::net::tls::TlsConfig; +use crate::net::Socket; +use crate::Error; + +use native_tls::HandshakeError; +use std::task::{Context, Poll}; + +pub struct NativeTlsSocket { + stream: native_tls::TlsStream>, +} + +impl Socket for NativeTlsSocket { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + self.stream.read(buf.init_mut()) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + self.stream.write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.stream.get_mut().poll_ready(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.stream.get_mut().poll_ready(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.stream.shutdown() { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => self.stream.get_mut().poll_ready(cx), + ready => Poll::Ready(ready), + } + } +} + +/// DEPRECATED: this should never have been public. +impl From for Error { + fn from(e: native_tls::Error) -> Self { + Error::Tls(Box::new(e)) + } +} + +pub async fn handshake( + socket: S, + config: TlsConfig<'_>, +) -> crate::Result> { + let mut builder = native_tls::TlsConnector::builder(); + + builder + .danger_accept_invalid_certs(config.accept_invalid_certs) + .danger_accept_invalid_hostnames(config.accept_invalid_hostnames); + + if let Some(root_cert_path) = config.root_cert_path { + let data = root_cert_path.data().await?; + builder.add_root_certificate(native_tls::Certificate::from_pem(&data)?); + } + + let connector = builder.build()?; + + let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) { + Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }), + Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)), + Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake, + }; + + loop { + mid_handshake.get_mut().ready().await?; + + match mid_handshake.handshake() { + Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }), + Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)), + Err(HandshakeError::WouldBlock(mid_handshake_)) => { + mid_handshake = mid_handshake_; + } + } + } +} diff --git a/sqlx-core/src/net/tls/tls_rustls.rs b/sqlx-core/src/net/tls/tls_rustls.rs new file mode 100644 index 0000000000..230e03527f --- /dev/null +++ b/sqlx-core/src/net/tls/tls_rustls.rs @@ -0,0 +1,184 @@ +use futures_util::future; +use std::io; +use std::io::{Cursor, Read, Write}; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::SystemTime; + +use rustls::{ + client::{ServerCertVerified, ServerCertVerifier, WebPkiVerifier}, + ClientConfig, ClientConnection, Error as TlsError, OwnedTrustAnchor, RootCertStore, ServerName, +}; + +use crate::error::Error; +use crate::io::ReadBuf; +use crate::net::tls::util::StdSocket; +use crate::net::tls::TlsConfig; +use crate::net::Socket; + +pub struct RustlsSocket { + inner: StdSocket, + state: ClientConnection, + close_notify_sent: bool, +} + +impl RustlsSocket { + fn poll_complete_io(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state.complete_io(&mut self.inner) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + futures_util::ready!(self.inner.poll_ready(cx))?; + } + ready => return Poll::Ready(ready.map(|_| ())), + } + } + } + + async fn complete_io(&mut self) -> io::Result<()> { + future::poll_fn(|cx| self.poll_complete_io(cx)).await + } +} + +impl Socket for RustlsSocket { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + self.state.reader().read(buf.init_mut()) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + match self.state.writer().write(buf) { + // Returns a zero-length write when the buffer is full. + Ok(0) => Err(io::ErrorKind::WouldBlock.into()), + other => return other, + } + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_complete_io(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_complete_io(cx) + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_complete_io(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.close_notify_sent { + self.state.send_close_notify(); + self.close_notify_sent = true; + } + + futures_util::ready!(self.poll_complete_io(cx))?; + self.inner.socket.poll_shutdown(cx) + } +} + +pub async fn handshake(socket: S, tls_config: TlsConfig<'_>) -> Result, Error> +where + S: Socket, +{ + let config = ClientConfig::builder().with_safe_defaults(); + + let config = if tls_config.accept_invalid_certs { + config + .with_custom_certificate_verifier(Arc::new(DummyTlsVerifier)) + .with_no_client_auth() + } else { + let mut cert_store = RootCertStore::empty(); + cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + if let Some(ca) = tls_config.root_cert_path { + let data = ca.data().await?; + let mut cursor = Cursor::new(data); + + for cert in rustls_pemfile::certs(&mut cursor) + .map_err(|_| Error::Tls(format!("Invalid certificate {}", ca).into()))? + { + cert_store + .add(&rustls::Certificate(cert)) + .map_err(|err| Error::Tls(err.into()))?; + } + } + + if tls_config.accept_invalid_hostnames { + let verifier = WebPkiVerifier::new(cert_store, None); + + config + .with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier })) + .with_no_client_auth() + } else { + config + .with_root_certificates(cert_store) + .with_no_client_auth() + } + }; + + let host = rustls::ServerName::try_from(tls_config.hostname).map_err(Error::tls)?; + + let mut socket = RustlsSocket { + inner: StdSocket::new(socket), + state: ClientConnection::new(Arc::new(config), host).map_err(Error::tls)?, + close_notify_sent: false, + }; + + // Performs the TLS handshake or bails + socket.complete_io().await?; + + Ok(socket) +} + +struct DummyTlsVerifier; + +impl ServerCertVerifier for DummyTlsVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } +} + +pub struct NoHostnameTlsVerifier { + verifier: WebPkiVerifier, +} + +impl ServerCertVerifier for NoHostnameTlsVerifier { + fn verify_server_cert( + &self, + end_entity: &rustls::Certificate, + intermediates: &[rustls::Certificate], + server_name: &ServerName, + scts: &mut dyn Iterator, + ocsp_response: &[u8], + now: SystemTime, + ) -> Result { + match self.verifier.verify_server_cert( + end_entity, + intermediates, + server_name, + scts, + ocsp_response, + now, + ) { + Err(TlsError::InvalidCertificateData(reason)) + if reason.contains("CertNotValidForName") => + { + Ok(ServerCertVerified::assertion()) + } + res => res, + } + } +} diff --git a/sqlx-core/src/net/tls/util.rs b/sqlx-core/src/net/tls/util.rs new file mode 100644 index 0000000000..02a16ef5e1 --- /dev/null +++ b/sqlx-core/src/net/tls/util.rs @@ -0,0 +1,65 @@ +use crate::net::Socket; + +use std::io::{self, Read, Write}; +use std::task::{Context, Poll}; + +use futures_core::ready; +use futures_util::future; + +pub struct StdSocket { + pub socket: S, + wants_read: bool, + wants_write: bool, +} + +impl StdSocket { + pub fn new(socket: S) -> Self { + Self { + socket, + wants_read: false, + wants_write: false, + } + } + + pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.wants_write { + ready!(self.socket.poll_write_ready(cx))?; + self.wants_write = false; + } + + if self.wants_read { + ready!(self.socket.poll_read_ready(cx))?; + self.wants_read = false; + } + + Poll::Ready(Ok(())) + } + + pub async fn ready(&mut self) -> io::Result<()> { + future::poll_fn(|cx| self.poll_ready(cx)).await + } +} + +impl Read for StdSocket { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + self.wants_read = true; + let read = self.socket.try_read(&mut buf)?; + self.wants_read = false; + + Ok(read) + } +} + +impl Write for StdSocket { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.wants_write = true; + let written = self.socket.try_write(buf)?; + self.wants_write = false; + Ok(written) + } + + fn flush(&mut self) -> io::Result<()> { + // NOTE: TCP sockets and unix sockets are both no-ops for flushes + Ok(()) + } +} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 9c61547cbe..ade6326282 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::{Duration, Instant}; -use futures_intrusive::sync::SemaphoreReleaser; +use crate::sync::AsyncSemaphoreReleaser; use crate::connection::Connection; use crate::database::Database; @@ -134,13 +134,7 @@ impl Drop for PoolConnection { fn drop(&mut self) { // We still need to spawn a task to maintain `min_connections`. if self.live.is_some() || self.pool.options.min_connections > 0 { - #[cfg(not(feature = "_rt-async-std"))] - if let Ok(handle) = sqlx_rt::Handle::try_current() { - handle.spawn(self.return_to_pool()); - } - - #[cfg(feature = "_rt-async-std")] - sqlx_rt::spawn(self.return_to_pool()); + crate::rt::spawn(self.return_to_pool()); } } } @@ -288,7 +282,7 @@ impl Floating> { pub fn from_idle( idle: Idle, pool: Arc>, - permit: SemaphoreReleaser<'_>, + permit: AsyncSemaphoreReleaser<'_>, ) -> Self { Self { inner: idle, diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 7bfae7fc78..1d7d4ba647 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -6,7 +6,7 @@ use crate::error::Error; use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions}; use crossbeam_queue::ArrayQueue; -use futures_intrusive::sync::{Semaphore, SemaphoreReleaser}; +use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; use std::cmp; use std::future::Future; @@ -22,7 +22,7 @@ use std::time::{Duration, Instant}; pub(crate) struct PoolInner { pub(super) connect_options: ::Options, pub(super) idle_conns: ArrayQueue>, - pub(super) semaphore: Semaphore, + pub(super) semaphore: AsyncSemaphore, pub(super) size: AtomicU32, pub(super) num_idle: AtomicUsize, is_closed: AtomicBool, @@ -49,7 +49,7 @@ impl PoolInner { let pool = Self { connect_options, idle_conns: ArrayQueue::new(capacity), - semaphore: Semaphore::new(options.fair, semaphore_capacity), + semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity), size: AtomicU32::new(0), num_idle: AtomicUsize::new(0), is_closed: AtomicBool::new(false), @@ -86,7 +86,7 @@ impl PoolInner { self.on_closed.notify(usize::MAX); async move { - for permits in 1..=self.options.max_connections as usize { + for permits in 1..=self.options.max_connections { // Close any currently idle connections in the pool. while let Some(idle) = self.idle_conns.pop() { let _ = idle.live.float((*self).clone()).close().await; @@ -112,7 +112,7 @@ impl PoolInner { /// /// If we steal a permit from the parent but *don't* open a connection, /// it should be returned to the parent. - async fn acquire_permit<'a>(self: &'a Arc) -> Result, Error> { + async fn acquire_permit<'a>(self: &'a Arc) -> Result, Error> { let parent = self .parent() // If we're already at the max size, we shouldn't try to steal from the parent. @@ -182,8 +182,8 @@ impl PoolInner { fn pop_idle<'a>( self: &'a Arc, - permit: SemaphoreReleaser<'a>, - ) -> Result>, SemaphoreReleaser<'a>> { + permit: AsyncSemaphoreReleaser<'a>, + ) -> Result>, AsyncSemaphoreReleaser<'a>> { if let Some(idle) = self.idle_conns.pop() { self.num_idle.fetch_sub(1, Ordering::AcqRel); Ok(Floating::from_idle(idle, (*self).clone(), permit)) @@ -211,8 +211,8 @@ impl PoolInner { /// Try to atomically increment the pool size for a new connection. pub(super) fn try_increment_size<'a>( self: &'a Arc, - permit: SemaphoreReleaser<'a>, - ) -> Result, SemaphoreReleaser<'a>> { + permit: AsyncSemaphoreReleaser<'a>, + ) -> Result, AsyncSemaphoreReleaser<'a>> { match self .size .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { @@ -233,7 +233,7 @@ impl PoolInner { let deadline = Instant::now() + self.options.acquire_timeout; - sqlx_rt::timeout( + crate::rt::timeout( self.options.acquire_timeout, async { loop { @@ -263,7 +263,7 @@ impl PoolInner { // If so, we're likely in the current-thread runtime if it's Tokio // and so we should yield to let any spawned release_to_pool() tasks // execute. - sqlx_rt::yield_now().await; + crate::rt::yield_now().await; continue; } }; @@ -294,7 +294,7 @@ impl PoolInner { // result here is `Result, TimeoutError>` // if this block does not return, sleep for the backoff timeout and try again - match sqlx_rt::timeout(timeout, self.connect_options.connect()).await { + match crate::rt::timeout(timeout, self.connect_options.connect()).await { // successfully established connection Ok(Ok(mut raw)) => { // See comment on `PoolOptions::after_connect` @@ -338,7 +338,7 @@ impl PoolInner { // If the connection is refused, wait in exponentially // increasing steps for the server to come up, // capped by a factor of the remaining time until the deadline - sqlx_rt::sleep(backoff).await; + crate::rt::sleep(backoff).await; backoff = cmp::min(backoff * 2, max_backoff); } } @@ -467,7 +467,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { (None, None) => { if pool.options.min_connections > 0 { - sqlx_rt::spawn(async move { + crate::rt::spawn(async move { pool.min_connections_maintenance(None).await; }); } @@ -476,7 +476,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { } }; - sqlx_rt::spawn(async move { + crate::rt::spawn(async move { // Immediately cancel this task if the pool is closed. let _ = pool .close_event() @@ -488,9 +488,9 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if let Some(duration) = next_run.checked_duration_since(Instant::now()) { // `async-std` doesn't have a `sleep_until()` - sqlx_rt::sleep(duration).await; + crate::rt::sleep(duration).await; } else { - sqlx_rt::yield_now().await; + crate::rt::yield_now().await; } // Don't run the reaper right away. @@ -544,7 +544,7 @@ impl DecrementSizeGuard { } } - pub fn from_permit(pool: Arc>, mut permit: SemaphoreReleaser<'_>) -> Self { + pub fn from_permit(pool: Arc>, permit: AsyncSemaphoreReleaser<'_>) -> Self { // here we effectively take ownership of the permit permit.disarm(); Self::new_permit(pool) diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs index cd163c5039..feb2c9c9e4 100644 --- a/sqlx-core/src/postgres/connection/establish.rs +++ b/sqlx-core/src/postgres/connection/establish.rs @@ -3,7 +3,7 @@ use crate::HashMap; use crate::common::StatementCache; use crate::error::Error; use crate::io::Decode; -use crate::postgres::connection::{sasl, stream::PgStream, tls}; +use crate::postgres::connection::{sasl, stream::PgStream}; use crate::postgres::message::{ Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, }; @@ -15,10 +15,8 @@ use crate::postgres::{PgConnectOptions, PgConnection}; impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { - let mut stream = PgStream::connect(options).await?; - // Upgrade to TLS if we were asked to and the server supports it - tls::maybe_upgrade(&mut stream, options).await?; + let mut stream = PgStream::connect(options).await?; // To begin a session, a frontend opens a connection to the server // and sends a startup message. diff --git a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index 325b565c3b..3252857414 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -73,7 +73,7 @@ impl PgConnection { // will return when the connection is ready for another query pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.stream.wbuf.is_empty() { + if !self.stream.write_buffer_mut().is_empty() { self.stream.flush().await?; } @@ -203,6 +203,6 @@ impl Connection for PgConnection { #[doc(hidden)] fn should_flush(&self) -> bool { - !self.stream.wbuf.is_empty() + !self.stream.write_buffer().is_empty() } } diff --git a/sqlx-core/src/postgres/connection/stream.rs b/sqlx-core/src/postgres/connection/stream.rs index 59b5289b8e..3e76da2f48 100644 --- a/sqlx-core/src/postgres/connection/stream.rs +++ b/sqlx-core/src/postgres/connection/stream.rs @@ -8,8 +8,9 @@ use futures_util::SinkExt; use log::Level; use crate::error::Error; -use crate::io::{BufStream, Decode, Encode}; -use crate::net::{MaybeTlsStream, Socket}; +use crate::io::{Decode, Encode}; +use crate::net::{self, BufferedSocket, Socket}; +use crate::postgres::connection::tls::MaybeUpgradeTls; use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity}; @@ -23,7 +24,9 @@ use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity}; // is fully prepared to receive queries pub struct PgStream { - inner: BufStream>, + // A trait object is okay here as the buffering amortizes the overhead of both the dynamic + // function call as well as the syscall. + inner: BufferedSocket>, // buffer of unreceived notification messages from `PUBLISH` // this is set when creating a PgListener and only written to if that listener is @@ -37,15 +40,15 @@ pub struct PgStream { impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { - let socket = match options.fetch_socket() { - Some(ref path) => Socket::connect_uds(path).await?, - None => Socket::connect_tcp(&options.host, options.port).await?, + let socket_future = match options.fetch_socket() { + Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, + None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, }; - let inner = BufStream::new(MaybeTlsStream::Raw(socket)); + let socket = socket_future.await?; Ok(Self { - inner, + inner: BufferedSocket::new(socket), notifications: None, parameter_statuses: BTreeMap::default(), server_version_num: None, @@ -57,7 +60,8 @@ impl PgStream { T: Encode<'en>, { self.write(message); - self.flush().await + self.flush().await?; + Ok(()) } // Expect a specific type and format @@ -171,7 +175,7 @@ impl PgStream { } impl Deref for PgStream { - type Target = BufStream>; + type Target = BufferedSocket>; #[inline] fn deref(&self) -> &Self::Target { diff --git a/sqlx-core/src/postgres/connection/tls.rs b/sqlx-core/src/postgres/connection/tls.rs index 0c780f401a..882b54ec52 100644 --- a/sqlx-core/src/postgres/connection/tls.rs +++ b/sqlx-core/src/postgres/connection/tls.rs @@ -1,78 +1,100 @@ -use bytes::Bytes; +use futures_core::future::BoxFuture; use crate::error::Error; -use crate::postgres::connection::stream::PgStream; +use crate::net::tls::{self, TlsConfig}; +use crate::net::{Socket, SocketIntoBox, WithSocket}; + use crate::postgres::message::SslRequest; use crate::postgres::{PgConnectOptions, PgSslMode}; -pub(super) async fn maybe_upgrade( - stream: &mut PgStream, +pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); + +impl<'a> WithSocket for MaybeUpgradeTls<'a> { + type Output = BoxFuture<'a, crate::Result>>; + + fn with_socket(self, socket: S) -> Self::Output { + Box::pin(maybe_upgrade(socket, self.0)) + } +} + +async fn maybe_upgrade( + mut socket: S, options: &PgConnectOptions, -) -> Result<(), Error> { +) -> Result, Error> { // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS match options.ssl_mode { // FIXME: Implement ALLOW - PgSslMode::Allow | PgSslMode::Disable => {} + PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)), PgSslMode::Prefer => { + if !tls::available() { + return Ok(Box::new(socket)); + } + // try upgrade, but its okay if we fail - upgrade(stream, options).await?; + if !request_upgrade(&mut socket, options).await? { + return Ok(Box::new(socket)); + } } PgSslMode::Require | PgSslMode::VerifyFull | PgSslMode::VerifyCa => { - if !upgrade(stream, options).await? { + tls::error_if_unavailable()?; + + if !request_upgrade(&mut socket, options).await? { // upgrade failed, die return Err(Error::Tls("server does not support TLS".into())); } } } - Ok(()) + let accept_invalid_certs = !matches!( + options.ssl_mode, + PgSslMode::VerifyCa | PgSslMode::VerifyFull + ); + let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); + + let config = TlsConfig { + accept_invalid_certs, + accept_invalid_hostnames, + hostname: &options.host, + root_cert_path: options.ssl_root_cert.as_ref(), + }; + + tls::handshake(socket, config, SocketIntoBox).await } -async fn upgrade(stream: &mut PgStream, options: &PgConnectOptions) -> Result { +async fn request_upgrade( + socket: &mut impl Socket, + _options: &PgConnectOptions, +) -> Result { // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 // To initiate an SSL-encrypted connection, the frontend initially sends an // SSLRequest message rather than a StartupMessage - stream.send(SslRequest).await?; + socket.write(SslRequest::BYTES).await?; // The server then responds with a single byte containing S or N, indicating that // it is willing or unwilling to perform SSL, respectively. - match stream.read::(1).await?[0] { + let mut response = [0u8]; + + socket.read(&mut &mut response[..]).await?; + + match response[0] { b'S' => { // The server is ready and willing to accept an SSL connection + Ok(true) } b'N' => { // The server is _unwilling_ to perform SSL - return Ok(false); + Ok(false) } - other => { - return Err(err_protocol!( - "unexpected response from SSLRequest: 0x{:02x}", - other - )); - } + other => Err(err_protocol!( + "unexpected response from SSLRequest: 0x{:02x}", + other + )), } - - let accept_invalid_certs = !matches!( - options.ssl_mode, - PgSslMode::VerifyCa | PgSslMode::VerifyFull - ); - let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); - - stream - .upgrade( - &options.host, - accept_invalid_certs, - accept_invalid_hostnames, - options.ssl_root_cert.as_ref(), - ) - .await?; - - Ok(true) } diff --git a/sqlx-core/src/postgres/copy.rs b/sqlx-core/src/postgres/copy.rs index 0bad775085..5047bddb67 100644 --- a/sqlx-core/src/postgres/copy.rs +++ b/sqlx-core/src/postgres/copy.rs @@ -1,16 +1,24 @@ +use std::borrow::Cow; +use std::ops::{Deref, DerefMut}; + +use bytes::{BufMut, Bytes}; +use futures_core::stream::BoxStream; + use crate::error::{Error, Result}; use crate::ext::async_stream::TryAsyncStream; +use crate::io::AsyncRead; use crate::pool::{Pool, PoolConnection}; use crate::postgres::connection::PgConnection; use crate::postgres::message::{ CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, }; use crate::postgres::Postgres; -use bytes::{BufMut, Bytes}; -use futures_core::stream::BoxStream; -use smallvec::alloc::borrow::Cow; -use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt}; -use std::ops::{Deref, DerefMut}; + +#[cfg(not(feature = "_rt-tokio"))] +use futures_util::io::AsyncReadExt; + +#[cfg(feature = "_rt-tokio")] +use tokio::io::AsyncReadExt; impl PgConnection { /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data @@ -172,8 +180,16 @@ impl> PgCopyIn { /// /// `source` will be read to the end. /// - /// ### Note + /// ### Note: Completion Step Required /// You must still call either [Self::finish] or [Self::abort] to complete the process. + /// + /// ### Note: Runtime Features + /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std` + /// depending on which runtime feature is used. + /// + /// The runtime features _used_ to be mutually exclusive, but are no longer. + /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version + /// takes precedent. pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing struct BufGuard<'s>(&'s mut Vec); @@ -189,46 +205,34 @@ impl> PgCopyIn { // flush any existing messages in the buffer and clear it conn.stream.flush().await?; - { - let buf_stream = &mut *conn.stream; - let stream = &mut buf_stream.stream; - - // ensures the buffer isn't left in an inconsistent state - let mut guard = BufGuard(&mut buf_stream.wbuf); - - let buf: &mut Vec = &mut guard.0; - buf.push(b'd'); // CopyData format code - buf.resize(5, 0); // reserve space for the length - - loop { - let read = match () { - // Tokio lets us read into the buffer without zeroing first - #[cfg(feature = "runtime-tokio")] - _ if buf.len() != buf.capacity() => { - // in case we have some data in the buffer, which can occur - // if the previous write did not fill the buffer - buf.truncate(5); - source.read_buf(buf).await? - } - _ => { - // should be a no-op unless len != capacity - buf.resize(buf.capacity(), 0); - source.read(&mut buf[5..]).await? - } - }; + loop { + let buf = conn.stream.write_buffer_mut(); + + // CopyData format code and reserved space for length + buf.put_slice(b"d\0\0\0\0"); + + let read = match () { + // Tokio lets us read into the buffer without zeroing first + #[cfg(feature = "_rt-tokio")] + _ => source.read_buf(buf.buf_mut()).await?, + #[cfg(not(feature = "_rt-tokio"))] + _ => source.read(buf.init_remaining_mut()).await?, + }; + + if read == 0 { + // This will end up sending an empty `CopyData` packet but that should be fine. + break; + } - if read == 0 { - break; - } + buf.advance(read); - let read32 = u32::try_from(read) - .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + // Write the length + let read32 = u32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; - (&mut buf[1..]).put_u32(read32 + 4); + (&mut buf.get_mut()[1..]).put_u32(read32 + 4); - stream.write_all(&buf[..read + 5]).await?; - stream.flush().await?; - } + conn.stream.flush().await?; } Ok(self) diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 1432ae6c06..f5ad29eb06 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -191,8 +191,8 @@ impl PgListener { /// # use sqlx_core::postgres::PgListener; /// # use sqlx_core::error::Error; /// # - /// # #[cfg(feature = "_rt-async-std")] - /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move { + /// # #[cfg(feature = "_rt")] + /// # sqlx::__rt::test_block_on(async move { /// # let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { /// // ask for next notification, re-connecting (transparently) if needed @@ -200,7 +200,7 @@ impl PgListener { /// /// // handle notification, do something interesting /// } - /// # Ok(()) + /// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn recv(&mut self) -> Result { @@ -222,8 +222,8 @@ impl PgListener { /// # use sqlx_core::postgres::PgListener; /// # use sqlx_core::error::Error; /// # - /// # #[cfg(feature = "_rt-async-std")] - /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move { + /// # #[cfg(feature = "_rt")] + /// # sqlx::__rt::test_block_on(async move { /// # let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { /// // start handling notifications, connecting if needed @@ -233,7 +233,7 @@ impl PgListener { /// /// // connection lost, do something interesting /// } - /// # Ok(()) + /// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn try_recv(&mut self) -> Result, Error> { @@ -321,13 +321,7 @@ impl Drop for PgListener { }; // Unregister any listeners before returning the connection to the pool. - #[cfg(not(feature = "_rt-async-std"))] - if let Ok(handle) = sqlx_rt::Handle::try_current() { - handle.spawn(fut); - } - - #[cfg(feature = "_rt-async-std")] - sqlx_rt::spawn(fut); + crate::rt::spawn(fut); } } } diff --git a/sqlx-core/src/postgres/message/ssl_request.rs b/sqlx-core/src/postgres/message/ssl_request.rs index 7740c0be6e..fa57faf064 100644 --- a/sqlx-core/src/postgres/message/ssl_request.rs +++ b/sqlx-core/src/postgres/message/ssl_request.rs @@ -2,6 +2,10 @@ use crate::io::Encode; pub struct SslRequest; +impl SslRequest { + pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; +} + impl Encode<'_> for SslRequest { #[inline] fn encode_with(&self, buf: &mut Vec, _: ()) { @@ -12,10 +16,8 @@ impl Encode<'_> for SslRequest { #[test] fn test_encode_ssl_request() { - const EXPECTED: &[u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; - let mut buf = Vec::new(); SslRequest.encode(&mut buf); - assert_eq!(buf, EXPECTED); + assert_eq!(buf, SslRequest::BYTES); } diff --git a/sqlx-core/src/postgres/options/mod.rs b/sqlx-core/src/postgres/options/mod.rs index e870da0943..9bca6629df 100644 --- a/sqlx-core/src/postgres/options/mod.rs +++ b/sqlx-core/src/postgres/options/mod.rs @@ -3,12 +3,14 @@ use std::env::var; use std::fmt::{Display, Write}; use std::path::{Path, PathBuf}; +pub use ssl_mode::PgSslMode; + +use crate::{connection::LogSettings, net::tls::CertificateInput}; + mod connect; mod parse; mod pgpass; mod ssl_mode; -use crate::{connection::LogSettings, net::CertificateInput}; -pub use ssl_mode::PgSslMode; /// Options and flags which can be used to configure a PostgreSQL connection. /// @@ -58,8 +60,8 @@ pub use ssl_mode::PgSslMode; /// # use sqlx_core::postgres::{PgConnectOptions, PgConnection, PgSslMode}; /// # /// # fn main() { -/// # #[cfg(feature = "_rt-async-std")] -/// # sqlx_rt::async_std::task::block_on::<_, Result<(), Error>>(async move { +/// # #[cfg(feature = "_rt")] +/// # sqlx::__rt::test_block_on(async move { /// // URL connection string /// let conn = PgConnection::connect("postgres://localhost/mydb").await?; /// @@ -71,7 +73,7 @@ pub use ssl_mode::PgSslMode; /// .password("secret-password") /// .ssl_mode(PgSslMode::Require) /// .connect().await?; -/// # Ok(()) +/// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// # } /// ``` diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index b3e30dc52c..2d7a1c9b15 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -395,7 +395,7 @@ where } // Make a SQL query from a statement. -pub(crate) fn query_statement<'q, DB>( +pub fn query_statement<'q, DB>( statement: &'q >::Statement, ) -> Query<'q, DB, >::Arguments> where @@ -410,7 +410,7 @@ where } // Make a SQL query from a statement, with the given arguments. -pub(crate) fn query_statement_with<'q, DB, A>( +pub fn query_statement_with<'q, DB, A>( statement: &'q >::Statement, arguments: A, ) -> Query<'q, DB, A> diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 7a15150cbc..43a611cf65 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -195,7 +195,7 @@ where } // Make a SQL query from a statement, that is mapped to a concrete type. -pub(crate) fn query_statement_as<'q, DB, O>( +pub fn query_statement_as<'q, DB, O>( statement: &'q >::Statement, ) -> QueryAs<'q, DB, O, >::Arguments> where @@ -209,7 +209,7 @@ where } // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete type. -pub(crate) fn query_statement_as_with<'q, DB, O, A>( +pub fn query_statement_as_with<'q, DB, O, A>( statement: &'q >::Statement, arguments: A, ) -> QueryAs<'q, DB, O, A> diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index 82e2c3070f..10dd4d7e9f 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -541,7 +541,7 @@ where } } -#[cfg(test)] +#[cfg(all(test, feature = "postgres"))] mod test { use crate::postgres::Postgres; diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 19a78287bc..197c527e56 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -188,7 +188,7 @@ where } // Make a SQL query from a statement, that is mapped to a concrete value. -pub(crate) fn query_statement_scalar<'q, DB, O>( +pub fn query_statement_scalar<'q, DB, O>( statement: &'q >::Statement, ) -> QueryScalar<'q, DB, O, >::Arguments> where @@ -201,7 +201,7 @@ where } // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete value. -pub(crate) fn query_statement_scalar_with<'q, DB, O, A>( +pub fn query_statement_scalar_with<'q, DB, O, A>( statement: &'q >::Statement, arguments: A, ) -> QueryScalar<'q, DB, O, A> diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs new file mode 100644 index 0000000000..c5de4d9cdf --- /dev/null +++ b/sqlx-core/src/rt/mod.rs @@ -0,0 +1,168 @@ +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +#[cfg(feature = "_rt-async-std")] +pub mod rt_async_std; + +#[cfg(feature = "_rt-tokio")] +pub mod rt_tokio; + +#[derive(Debug, thiserror::Error)] +#[error("operation timed out")] +pub struct TimeoutError(()); + +pub enum JoinHandle { + #[cfg(feature = "_rt-async-std")] + AsyncStd(async_std::task::JoinHandle), + #[cfg(feature = "_rt-tokio")] + Tokio(tokio::task::JoinHandle), + // `PhantomData` requires `T: Unpin` + _Phantom(PhantomData T>), +} + +#[track_caller] +pub async fn timeout(duration: Duration, f: F) -> Result { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::time::timeout(duration, f) + .await + .map_err(|_| TimeoutError(())); + } + + #[cfg(feature = "_rt-async-std")] + { + return async_std::future::timeout(duration, f) + .await + .map_err(|_| TimeoutError(())); + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt((duration, f)) +} + +#[track_caller] +pub async fn sleep(duration: Duration) { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::time::sleep(duration).await; + } + + #[cfg(feature = "_rt-async-std")] + { + return async_std::task::sleep(duration).await; + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt(duration) +} + +#[track_caller] +pub fn spawn(fut: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "_rt-tokio")] + if let Ok(handle) = tokio::runtime::Handle::try_current() { + return JoinHandle::Tokio(handle.spawn(fut)); + } + + #[cfg(feature = "_rt-async-std")] + { + return JoinHandle::AsyncStd(async_std::task::spawn(fut)); + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt(fut) +} + +#[track_caller] +pub fn spawn_blocking(f: F) -> JoinHandle +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + #[cfg(feature = "_rt-tokio")] + if let Ok(handle) = tokio::runtime::Handle::try_current() { + return JoinHandle::Tokio(handle.spawn_blocking(f)); + } + + #[cfg(feature = "_rt-async-std")] + { + return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f)); + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt(f) +} + +#[track_caller] +pub async fn yield_now() { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::task::yield_now().await; + } + + #[cfg(feature = "_rt-async-std")] + { + return async_std::task::yield_now().await; + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt(()) +} + +#[track_caller] +pub fn test_block_on(f: F) -> F::Output { + #[cfg(feature = "_rt-tokio")] + { + return tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to start Tokio runtime") + .block_on(f); + } + + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + { + return async_std::task::block_on(f); + } + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + { + drop(f); + panic!("at least one of the `runtime-*` features must be enabled") + } +} + +#[track_caller] +pub fn missing_rt(_unused: T) -> ! { + if cfg!(feature = "_rt-tokio") { + panic!("this functionality requires a Tokio context") + } + + panic!("at least one of the `runtime-*` features must be enabled") +} + +impl Future for JoinHandle { + type Output = T; + + #[track_caller] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match &mut *self { + #[cfg(feature = "_rt-async-std")] + Self::AsyncStd(handle) => Pin::new(handle).poll(cx), + #[cfg(feature = "_rt-tokio")] + Self::Tokio(handle) => Pin::new(handle) + .poll(cx) + .map(|res| res.expect("spawned task panicked")), + Self::_Phantom(_) => { + let _ = cx; + unreachable!("runtime should have been checked on spawn") + } + } + } +} diff --git a/sqlx-core/src/rt/rt_async_std/mod.rs b/sqlx-core/src/rt/rt_async_std/mod.rs new file mode 100644 index 0000000000..b6d40b922b --- /dev/null +++ b/sqlx-core/src/rt/rt_async_std/mod.rs @@ -0,0 +1 @@ +mod socket; diff --git a/sqlx-core/src/rt/rt_async_std/socket.rs b/sqlx-core/src/rt/rt_async_std/socket.rs new file mode 100644 index 0000000000..2d66d70c76 --- /dev/null +++ b/sqlx-core/src/rt/rt_async_std/socket.rs @@ -0,0 +1,55 @@ +use crate::net::Socket; + +use std::io; +use std::io::{Read, Write}; +use std::net::{Shutdown, TcpStream}; + +use std::task::{Context, Poll}; + +use crate::io::ReadBuf; +use async_io::Async; + +impl Socket for Async { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + self.get_mut().read(buf.init_mut()) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + self.get_mut().write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_readable(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_writable(cx) + } + + fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) + } +} + +#[cfg(unix)] +impl Socket for Async { + fn try_read(&mut self, buf: &mut dyn ReadBuf) -> io::Result { + self.get_mut().read(buf.init_mut()) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + self.get_mut().write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_readable(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_writable(cx) + } + + fn poll_shutdown(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().shutdown(Shutdown::Both)) + } +} diff --git a/sqlx-core/src/rt/rt_tokio/mod.rs b/sqlx-core/src/rt/rt_tokio/mod.rs new file mode 100644 index 0000000000..ce699456db --- /dev/null +++ b/sqlx-core/src/rt/rt_tokio/mod.rs @@ -0,0 +1,5 @@ +mod socket; + +pub fn available() -> bool { + tokio::runtime::Handle::try_current().is_ok() +} diff --git a/sqlx-core/src/rt/rt_tokio/socket.rs b/sqlx-core/src/rt/rt_tokio/socket.rs new file mode 100644 index 0000000000..bb57cbfde9 --- /dev/null +++ b/sqlx-core/src/rt/rt_tokio/socket.rs @@ -0,0 +1,55 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::AsyncWrite; +use tokio::net::TcpStream; + +use crate::io::ReadBuf; +use crate::net::Socket; + +impl Socket for TcpStream { + fn try_read(&mut self, mut buf: &mut dyn ReadBuf) -> io::Result { + // Requires `&mut impl BufMut` + self.try_read_buf(&mut buf) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + (*self).try_write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (*self).poll_read_ready(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (*self).poll_write_ready(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(self).poll_shutdown(cx) + } +} + +#[cfg(unix)] +impl Socket for tokio::net::UnixStream { + fn try_read(&mut self, mut buf: &mut dyn ReadBuf) -> io::Result { + self.try_read_buf(&mut buf) + } + + fn try_write(&mut self, buf: &[u8]) -> io::Result { + (*self).try_write(buf) + } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (*self).poll_read_ready(cx) + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + (*self).poll_write_ready(cx) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(self).poll_shutdown(cx) + } +} diff --git a/sqlx-core/src/sqlite/migrate.rs b/sqlx-core/src/sqlite/migrate.rs index c07f6c8164..4f3b862c28 100644 --- a/sqlx-core/src/sqlite/migrate.rs +++ b/sqlx-core/src/sqlite/migrate.rs @@ -1,6 +1,7 @@ use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; +use crate::fs; use crate::migrate::MigrateError; use crate::migrate::{AppliedMigration, Migration}; use crate::migrate::{Migrate, MigrateDatabase}; @@ -9,7 +10,6 @@ use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::sqlite::{Sqlite, SqliteConnectOptions, SqliteConnection, SqliteJournalMode}; use futures_core::future::BoxFuture; -use sqlx_rt::fs; use std::str::FromStr; use std::sync::atomic::Ordering; use std::time::Duration; diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index 7070ec4dda..680fdfa10b 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -43,12 +43,12 @@ use indexmap::IndexMap; /// /// # fn main() { /// # #[cfg(feature = "_rt-async-std")] -/// # sqlx_rt::async_std::task::block_on::<_, Result<(), Error>>(async move { +/// # sqlx::__rt::test_block_on(async move { /// let conn = SqliteConnectOptions::from_str("sqlite://data.db")? /// .journal_mode(SqliteJournalMode::Wal) /// .read_only(true) /// .connect().await?; -/// # Ok(()) +/// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// # } /// ``` diff --git a/sqlx-core/src/sqlite/testing/mod.rs b/sqlx-core/src/sqlite/testing/mod.rs index f3e48e6b7c..fb51eb0c7e 100644 --- a/sqlx-core/src/sqlite/testing/mod.rs +++ b/sqlx-core/src/sqlite/testing/mod.rs @@ -16,12 +16,12 @@ impl TestSupport for Sqlite { } fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { - Box::pin(async move { Ok(sqlx_rt::fs::remove_file(db_name).await?) }) + Box::pin(async move { Ok(crate::fs::remove_file(db_name).await?) }) } fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { Box::pin(async move { - sqlx_rt::fs::remove_dir_all(BASE_PATH).await?; + crate::fs::remove_dir_all(BASE_PATH).await?; Ok(None) }) } @@ -37,13 +37,13 @@ async fn test_context(args: &TestArgs) -> Result, Error> { let db_path = convert_path(args.test_path); if let Some(parent_path) = Path::parent(db_path.as_ref()) { - sqlx_rt::fs::create_dir_all(parent_path) + crate::fs::create_dir_all(parent_path) .await .expect("failed to create folders"); } if Path::exists(db_path.as_ref()) { - sqlx_rt::fs::remove_file(&db_path) + crate::fs::remove_file(&db_path) .await .expect("failed to remove database from previous test run"); } diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 1260fa46da..348f60034a 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -88,6 +88,7 @@ pub trait Statement<'q>: Send + Sync { A: IntoArguments<'s, Self::Database>; } +#[macro_export] macro_rules! impl_statement_query { ($A:ty) => { #[inline] diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs new file mode 100644 index 0000000000..bc6aae7f96 --- /dev/null +++ b/sqlx-core/src/sync.rs @@ -0,0 +1,145 @@ +// For types with identical signatures that don't require runtime support, +// we can just arbitrarily pick one to use based on what's enabled. +// +// We'll generally lean towards Tokio's types as those are more featureful +// (including `tokio-console` support) and more widely deployed. + +#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] +pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; + +#[cfg(feature = "_rt-tokio")] +pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; + +pub struct AsyncSemaphore { + // We use the semaphore from futures-intrusive as the one from async-std + // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: + // * https://github.com/smol-rs/async-lock/issues/22 + // * https://github.com/smol-rs/async-lock/issues/23 + // + // We're on the look-out for a replacement, however, as futures-intrusive is not maintained + // and there are some soundness concerns (although it turns out any intrusive future is unsound + // in MIRI due to the necessitated mutable aliasing): + // https://github.com/launchbadge/sqlx/issues/1668 + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + inner: futures_intrusive::sync::Semaphore, + + #[cfg(feature = "_rt-tokio")] + inner: tokio::sync::Semaphore, +} + +impl AsyncSemaphore { + #[track_caller] + pub fn new(fair: bool, permits: usize) -> Self { + if cfg!(not(any(feature = "_rt-async-std", feature = "_rt-tokio"))) { + crate::rt::missing_rt((fair, permits)); + } + + AsyncSemaphore { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + inner: futures_intrusive::sync::Semaphore::new(fair, permits), + #[cfg(feature = "_rt-tokio")] + inner: { + debug_assert!(fair, "Tokio only has fair permits"); + tokio::sync::Semaphore::new(permits) + }, + } + } + + pub fn permits(&self) -> usize { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + return self.inner.permits(); + + #[cfg(feature = "_rt-tokio")] + return self.inner.available_permits(); + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + crate::rt::missing_rt(()) + } + + pub async fn acquire(&self, permits: u32) -> AsyncSemaphoreReleaser<'_> { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + return AsyncSemaphoreReleaser { + inner: self.inner.acquire(permits as usize).await, + }; + + #[cfg(feature = "_rt-tokio")] + return AsyncSemaphoreReleaser { + inner: self + .inner + // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts, + // but `u32` for this and `try_acquire_many()`. + .acquire_many(permits) + .await + .expect("BUG: we do not expose the `.close()` method"), + }; + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + crate::rt::missing_rt(permits) + } + + pub fn try_acquire(&self, permits: u32) -> Option> { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + return Some(AsyncSemaphoreReleaser { + inner: self.inner.try_acquire(permits as usize)?, + }); + + #[cfg(feature = "_rt-tokio")] + return Some(AsyncSemaphoreReleaser { + inner: self.inner.try_acquire_many(permits).ok()?, + }); + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + crate::rt::missing_rt(permits) + } + + pub fn release(&self, permits: usize) { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + return self.inner.release(permits); + + #[cfg(feature = "_rt-tokio")] + return self.inner.add_permits(permits); + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + crate::rt::missing_rt(permits) + } +} + +pub struct AsyncSemaphoreReleaser<'a> { + // We use the semaphore from futures-intrusive as the one from async-std + // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: + // * https://github.com/smol-rs/async-lock/issues/22 + // * https://github.com/smol-rs/async-lock/issues/23 + // + // We're on the look-out for a replacement, however, as futures-intrusive is not maintained + // and there are some soundness concerns (although it turns out any intrusive future is unsound + // in MIRI due to the necessitated mutable aliasing): + // https://github.com/launchbadge/sqlx/issues/1668 + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + inner: futures_intrusive::sync::SemaphoreReleaser<'a>, + + #[cfg(feature = "_rt-tokio")] + inner: tokio::sync::SemaphorePermit<'a>, + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + _phantom: std::marker::PhantomData<&'a ()>, +} + +impl AsyncSemaphoreReleaser<'_> { + pub fn disarm(self) { + #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] + { + let mut this = self; + this.inner.disarm(); + return; + } + + #[cfg(feature = "_rt-tokio")] + { + self.inner.forget(); + return; + } + + #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))] + crate::rt::missing_rt(()) + } +} diff --git a/sqlx-core/src/testing/mod.rs b/sqlx-core/src/testing/mod.rs index 5183c914a9..1725f9eca1 100644 --- a/sqlx-core/src/testing/mod.rs +++ b/sqlx-core/src/testing/mod.rs @@ -4,7 +4,6 @@ use std::time::Duration; use futures_core::future::BoxFuture; pub use fixtures::FixtureSnapshot; -use sqlx_rt::test_block_on; use crate::connection::{ConnectOptions, Connection}; use crate::database::Database; @@ -135,7 +134,7 @@ where args.fixtures.is_empty(), "fixtures cannot be applied for a bare function" ); - test_block_on(self()) + crate::rt::test_block_on(self()) } } @@ -187,7 +186,7 @@ where let res = test_fn(pool.clone()).await; - let close_timed_out = sqlx_rt::timeout(Duration::from_secs(10), pool.close()) + let close_timed_out = crate::rt::timeout(Duration::from_secs(10), pool.close()) .await .is_err(); @@ -208,7 +207,7 @@ where Fut: Future, Fut::Output: TestTermination, { - test_block_on(async move { + crate::rt::test_block_on(async move { let test_context = DB::test_context(&args) .await .expect("failed to connect to setup test database"); diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index 44a540be8e..b69cdb283e 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -13,7 +13,6 @@ use quote::{format_ident, quote}; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo}; -use sqlx_rt::{block_on, AsyncMutex}; use crate::database::DatabaseExt; use crate::query::data::QueryData; @@ -135,7 +134,7 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { )))] Metadata { offline: false, - database_url: Some(db_url), + database_url: Some(_db_url), .. } => Err( "At least one of the features ['postgres', 'mysql', 'mssql', 'sqlite'] must be enabled \ @@ -177,6 +176,7 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { #[cfg(not(feature = "offline"))] Metadata { offline: true, .. } => { + drop(input); Err("The cargo feature `offline` has to be enabled to use `SQLX_OFFLINE`".into()) } @@ -185,7 +185,10 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { offline: false, database_url: None, .. - } => Err("`DATABASE_URL` must be set to use query macros".into()), + } => { + drop(input); + Err("`DATABASE_URL` must be set to use query macros".into()) + }, } } @@ -215,7 +218,10 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result>> = Lazy::new(|| AsyncMutex::new(BTreeMap::new())); @@ -249,7 +255,12 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result Result { std::env::var(name) } } + +#[cfg(all(feature = "_rt-async-std", not(feature = "tokio")))] +use async_std::task::block_on; + +#[cfg(feature = "_rt-tokio")] +fn block_on(f: F) -> F::Output +where + F: std::future::Future, +{ + use tokio::runtime::{self, Runtime}; + + // We need a single, persistent Tokio runtime since we're caching connections, + // otherwise we'll get "IO driver has terminated" errors. + static TOKIO_RT: Lazy = Lazy::new(|| { + runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to start Tokio runtime") + }); + + TOKIO_RT.block_on(f) +} diff --git a/sqlx-rt/Cargo.toml b/sqlx-rt/Cargo.toml index 5db022a190..c3d58724a0 100644 --- a/sqlx-rt/Cargo.toml +++ b/sqlx-rt/Cargo.toml @@ -38,6 +38,8 @@ tokio-rustls = { version = "0.23.0", optional = true } native-tls = { version = "0.2.4", optional = true } once_cell = { version = "1.4", features = ["std"], optional = true } +futures-io = "0.3.21" + [dependencies.tokio] version = "1.0.1" features = ["fs", "net", "rt", "rt-multi-thread", "time", "io-util"] diff --git a/sqlx-rt/src/connect.rs b/sqlx-rt/src/connect.rs new file mode 100644 index 0000000000..fbbaae22a5 --- /dev/null +++ b/sqlx-rt/src/connect.rs @@ -0,0 +1,6 @@ +use std::future::Future; +use std::marker::PhantomData; + +use futures_io::{AsyncRead, AsyncWrite}; +use std::io; +use std::net::{SocketAddr, TcpStream}; diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index a0aac5b8ea..540bfbcec1 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -32,6 +32,8 @@ mod rt_async_std; #[cfg(any(feature = "_rt-tokio", feature = "_rt-actix"))] mod rt_tokio; +mod connect; + #[cfg(all(feature = "_tls-native-tls"))] pub use native_tls; diff --git a/src/lib.md b/src/lib.md new file mode 100644 index 0000000000..30a8b76d2c --- /dev/null +++ b/src/lib.md @@ -0,0 +1,62 @@ +The async SQL toolkit for Rust, built with ❤️ by [the LaunchBadge team]. + +See our [README] to get started or [browse our example projects]. +Have a question? [Check our FAQ] or [open a discussion]. + +### Runtime Support + +SQLx supports both the [Tokio] and [async-std] runtimes. + +You choose which runtime SQLx uses by default by enabling one of the following features: + +* `runtime-async-std` +* `runtime-tokio` + +The `runtime-actix` feature also exists but is an alias of `runtime-tokio`. + +If more than one runtime feature is enabled, the Tokio runtime is used if a Tokio context exists on the current +thread, i.e. [`tokio::runtime::Handle::try_current()`] returns `Ok`; `async-std` is used otherwise. + +Note that while SQLx no longer produces a compile error if zero or multiple runtime features are enabled, +which is useful for libraries building on top of it, +**the use of nearly any async function in the API will panic without at least one runtime feature enabled**. + +The chief exception is the SQLite driver, which is runtime-agnostic, including its integration with the query macros. +However, [`SqlitePool`][crate::sqlite::SqlitePool] _does_ require runtime support for timeouts and spawning +internal management tasks. + +### TLS Support + +For securely communicating with SQL servers over an untrusted network connection such as the internet, +you can enable Transport Layer Security (TLS) by enabling one of the following features: + +* `tls-native`: Enables the [`native-tls`] backend which uses the OS-native TLS capabilities: + * SecureTransport on macOS. + * SChannel on Windows. + * OpenSSL on all other platforms. +* `tls-rustls`: Enables the [RusTLS] backend, a crossplatform TLS library. + * Only supports TLS revisions 1.2 and 1.3. + * If you get `HandshakeFailure` errors when using this feature, it likely means your database server does not support + these newer revisions. This might be resolved by enabling or switching to the `tls-native` feature. + +If more than one TLS feature is enabled, the `tls-native` feature takes precedent so that it is only necessary to enable +it to see if it resolves the `HandshakeFailure` error without disabling `tls-rustls`. + +Consult the user manual for your database to find the TLS versions it supports. + +If your connection configuration requires a TLS upgrade but TLS support was not enabled, the connection attempt +will return an error. + +The legacy runtime+TLS combination feature flags are still supported, but for forward-compatibility, use of the separate +runtime and TLS feature flags is recommended. + +[the LaunchBadge team]: https://www.launchbadge.com +[README]: https://www.github.com/launchbadge/sqlx/tree/main/README.md +[browse our example projects]: https://www.github.com/launchbadge/sqlx/tree/main/examples +[Check our FAQ]: https://www.github.com/launchbadge/sqlx/tree/main/FAQ.md +[open a discussion]: https://github.com/launchbadge/sqlx/discussions/new?category=q-a +[Tokio]: https://www.tokio.rs +[async-std]: https://www.async.rs +[`tokio::runtime::Handle::try_current()`]: https://docs.rs/tokio/latest/tokio/runtime/struct.Handle.html#method.try_current +[`native-tls`]: https://docs.rs/native-tls/latest/native_tls/ +[RusTLS]: https://docs.rs/rustls/latest/rustls/ diff --git a/src/lib.rs b/src/lib.rs index e6487d1c10..708bdf78a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,5 @@ #![cfg_attr(docsrs, feature(doc_cfg))] - -#[cfg(any(feature = "runtime-async-std", feature = "runtime-tokio"))] -compile_error!( - "the features 'runtime-actix', 'runtime-async-std' and 'runtime-tokio' have been removed in - favor of new features 'runtime-{rt}-{tls}' where rt is one of 'actix', 'async-std' and 'tokio' - and 'tls' is one of 'native-tls' and 'rustls'." -); +#![doc = include_str!("lib.md")] pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; @@ -81,7 +75,7 @@ pub use sqlx_macros::test; pub use sqlx_core::testing; #[doc(hidden)] -pub use sqlx_core::test_block_on; +pub use sqlx_core::rt::test_block_on; #[cfg(feature = "macros")] mod macros; @@ -91,6 +85,9 @@ mod macros; #[doc(hidden)] pub mod ty_match; +#[doc(hidden)] +pub use sqlx_core::rt as __rt; + /// Conversions between Rust and SQL types. /// /// To see how each SQL type maps to a Rust type, see the corresponding `types` module for each diff --git a/src/macros/test.md b/src/macros/test.md index 05de0ffa39..13c241524e 100644 --- a/src/macros/test.md +++ b/src/macros/test.md @@ -1,11 +1,11 @@ Mark an `async fn` as a test with SQLx support. The test will automatically be executed in the async runtime according to the chosen -`runtime-{async-std, tokio}-{native-tls, rustls}` feature. +`runtime-{async-std, tokio}` feature. If more than one runtime feature is enabled, `runtime-tokio` is preferred. By default, this behaves identically to `#[tokio::test]`1 or `#[async_std::test]`: -```rust,norun +```rust # // Note if reading these examples directly in `test.md`: # // lines prefixed with `#` are not meant to be shown; # // they are supporting code to help the examples to compile successfully. diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 8fd73f9375..5bf29dab06 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -355,7 +355,7 @@ async fn test_issue_622() -> anyhow::Result<()> { for i in 0..3 { let pool = pool.clone(); - handles.push(sqlx_rt::spawn(async move { + handles.push(sqlx_core::rt::spawn(async move { { let mut conn = pool.acquire().await.unwrap(); @@ -366,7 +366,7 @@ async fn test_issue_622() -> anyhow::Result<()> { // (do some other work here without holding on to a connection) // this actually fixes the issue, depending on the timeout used - // sqlx_rt::sleep(Duration::from_millis(500)).await; + // sqlx_core::rt::sleep(Duration::from_millis(500)).await; { let start = Instant::now(); diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index d6f9cbac37..b545afce8e 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -553,7 +553,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { // spin up more tasks than connections available, and ensure we don't deadlock for i in 0..200 { let pool = pool.clone(); - sqlx_rt::spawn(async move { + sqlx_core::rt::spawn(async move { for j in 0.. { if let Err(e) = sqlx::query("select 1 + 1").execute(&pool).await { // normal error at termination of the test @@ -566,7 +566,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { } // shouldn't be necessary if the pool is fair - // sqlx_rt::yield_now().await; + // sqlx_core::rt::yield_now().await; } }); } @@ -575,7 +575,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> { // of cancellations for _ in 0..50 { let pool = pool.clone(); - sqlx_rt::spawn(async move { + sqlx_core::rt::spawn(async move { while !pool.is_closed() { let acquire = pool.acquire(); futures::pin_mut!(acquire); @@ -589,20 +589,20 @@ async fn pool_smoke_test() -> anyhow::Result<()> { // this one is necessary since this is a hot loop, // otherwise this task will never be descheduled - sqlx_rt::yield_now().await; + sqlx_core::rt::yield_now().await; } }); } eprintln!("sleeping for 30 seconds"); - sqlx_rt::sleep(Duration::from_secs(30)).await; + sqlx_core::rt::sleep(Duration::from_secs(30)).await; // assert_eq!(pool.size(), 10); eprintln!("closing pool"); - sqlx_rt::timeout(Duration::from_secs(30), pool.close()).await?; + sqlx_core::rt::timeout(Duration::from_secs(30), pool.close()).await?; eprintln!("pool closed successfully"); @@ -830,7 +830,7 @@ async fn test_issue_622() -> anyhow::Result<()> { for i in 0..3 { let pool = pool.clone(); - handles.push(sqlx_rt::spawn(async move { + handles.push(sqlx_core::rt::spawn(async move { { let mut conn = pool.acquire().await.unwrap(); @@ -841,7 +841,7 @@ async fn test_issue_622() -> anyhow::Result<()> { // (do some other work here without holding on to a connection) // this actually fixes the issue, depending on the timeout used - // sqlx_rt::sleep(Duration::from_millis(500)).await; + // sqlx_core::rt::sleep(Duration::from_millis(500)).await; { let start = Instant::now(); @@ -1008,7 +1008,7 @@ async fn test_pg_listener_allows_pool_to_close() -> anyhow::Result<()> { // acquires and holds a connection which would normally prevent the pool from closing let mut listener = PgListener::connect_with(&pool).await?; - sqlx_rt::spawn(async move { + sqlx_core::rt::spawn(async move { listener.recv().await.unwrap(); }); @@ -1671,7 +1671,7 @@ async fn test_advisory_locks() -> anyhow::Result<()> { // leak so we can take it across the task boundary let conn2_lock2 = lock2.acquire(conn2).await?.leak(); - sqlx_rt::spawn({ + sqlx_core::rt::spawn({ let lock1 = lock1.clone(); let lock2 = lock2.clone(); diff --git a/tests/sqlite/sqlcipher.rs b/tests/sqlite/sqlcipher.rs index 0a2a4499ea..335525fb0f 100644 --- a/tests/sqlite/sqlcipher.rs +++ b/tests/sqlite/sqlcipher.rs @@ -3,16 +3,12 @@ use std::str::FromStr; use sqlx::sqlite::SqliteQueryResult; use sqlx::{query, Connection, SqliteConnection}; use sqlx::{sqlite::SqliteConnectOptions, ConnectOptions}; -use sqlx_rt::fs::File; use tempdir::TempDir; async fn new_db_url() -> anyhow::Result<(String, TempDir)> { let dir = TempDir::new("sqlcipher_test")?; let filepath = dir.path().join("database.sqlite3"); - // Touch the file, so DB driver will not complain it does not exist - File::create(filepath.as_path()).await?; - Ok((format!("sqlite://{}", filepath.display()), dir)) } @@ -53,6 +49,7 @@ async fn it_encrypts() -> anyhow::Result<()> { let mut conn = SqliteConnectOptions::from_str(&url)? .pragma("key", "the_password") + .create_if_missing(true) .connect() .await?; @@ -77,6 +74,7 @@ async fn it_can_store_and_read_encrypted_data() -> anyhow::Result<()> { let mut conn = SqliteConnectOptions::from_str(&url)? .pragma("key", "the_password") + .create_if_missing(true) .connect() .await?; @@ -105,6 +103,7 @@ async fn it_fails_if_password_is_incorrect() -> anyhow::Result<()> { let mut conn = SqliteConnectOptions::from_str(&url)? .pragma("key", "the_password") + .create_if_missing(true) .connect() .await?; @@ -142,6 +141,7 @@ async fn it_honors_order_of_encryption_pragmas() -> anyhow::Result<()> { .pragma("kdf_iter", "64000") .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental) .pragma("cipher_hmac_algorithm", "HMAC_SHA1") + .create_if_missing(true) .connect() .await?; @@ -174,6 +174,7 @@ async fn it_allows_to_rekey_the_db() -> anyhow::Result<()> { let mut conn = SqliteConnectOptions::from_str(&url)? .pragma("key", "the_password") + .create_if_missing(true) .connect() .await?; diff --git a/tests/sqlite/sqlite.db b/tests/sqlite/sqlite.db index a3d8d5cc2964ec37bc568d4e00340395edd8a5d2..6e0930ff769997874fb2e9e692aef752a39cdb0b 100644 GIT binary patch delta 76 zcmZozz|^pSX#