diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 9c03f85788..a4ba260ddf 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -17,14 +17,14 @@ pub async fn create(connect_opts: &ConnectOpts) -> anyhow::Result<()> { std::sync::atomic::Ordering::Release, ); - Any::create_database(&connect_opts.database_url).await?; + Any::create_database(connect_opts.required_db_url()?).await?; } Ok(()) } pub async fn drop(connect_opts: &ConnectOpts, confirm: bool) -> anyhow::Result<()> { - if confirm && !ask_to_continue(connect_opts) { + if confirm && !ask_to_continue_drop(connect_opts.required_db_url()?) { return Ok(()); } @@ -33,7 +33,7 @@ pub async fn drop(connect_opts: &ConnectOpts, confirm: bool) -> anyhow::Result<( let exists = crate::retry_connect_errors(connect_opts, Any::database_exists).await?; if exists { - Any::drop_database(&connect_opts.database_url).await?; + Any::drop_database(connect_opts.required_db_url()?).await?; } Ok(()) @@ -53,12 +53,10 @@ pub async fn setup(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow migrate::run(migration_source, connect_opts, false, false, None).await } -fn ask_to_continue(connect_opts: &ConnectOpts) -> bool { +fn ask_to_continue_drop(db_url: &str) -> bool { loop { - let r: Result = prompt(format!( - "Drop database at {}? (y/n)", - style(&connect_opts.database_url).cyan() - )); + let r: Result = + prompt(format!("Drop database at {}? (y/n)", style(db_url).cyan())); match r { Ok(response) => { if response == "n" || response == "N" { diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index aec2392f75..9fd2ff60be 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -101,7 +101,7 @@ pub async fn run(opt: Opt) -> Result<()> { } /// Attempt to connect to the database server, retrying up to `ops.connect_timeout`. -async fn connect(opts: &ConnectOpts) -> sqlx::Result { +async fn connect(opts: &ConnectOpts) -> anyhow::Result { retry_connect_errors(opts, AnyConnection::connect).await } @@ -112,32 +112,34 @@ async fn connect(opts: &ConnectOpts) -> sqlx::Result { async fn retry_connect_errors<'a, F, Fut, T>( opts: &'a ConnectOpts, mut connect: F, -) -> sqlx::Result +) -> anyhow::Result where F: FnMut(&'a str) -> Fut, Fut: Future> + 'a, { sqlx::any::install_default_drivers(); + let db_url = opts.required_db_url()?; + backoff::future::retry( backoff::ExponentialBackoffBuilder::new() .with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout))) .build(), || { - connect(&opts.database_url).map_err(|e| -> backoff::Error { + connect(db_url).map_err(|e| -> backoff::Error { match e { sqlx::Error::Io(ref ioe) => match ioe.kind() { io::ErrorKind::ConnectionRefused | io::ErrorKind::ConnectionReset | io::ErrorKind::ConnectionAborted => { - return backoff::Error::transient(e); + return backoff::Error::transient(e.into()); } _ => (), }, _ => (), } - backoff::Error::permanent(e) + backoff::Error::permanent(e.into()) }) }, ) diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index d2dc0732ef..9f5630dd24 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -229,9 +229,9 @@ impl Deref for Source { /// Argument for the database URL. #[derive(Args, Debug)] pub struct ConnectOpts { - /// Location of the DB, by default will be read from the DATABASE_URL env var + /// Location of the DB, by default will be read from the DATABASE_URL env var or `.env` files. #[clap(long, short = 'D', env)] - pub database_url: String, + pub database_url: Option, /// The maximum time, in seconds, to try connecting to the database server before /// returning an error. @@ -251,6 +251,18 @@ pub struct ConnectOpts { pub sqlite_create_db_wal: bool, } +impl ConnectOpts { + /// Require a database URL to be provided, otherwise + /// return an error. + pub fn required_db_url(&self) -> anyhow::Result<&str> { + self.database_url.as_deref().ok_or_else( + || anyhow::anyhow!( + "the `--database-url` option the or `DATABASE_URL` environment variable must be provided" + ) + ) + } +} + /// Argument for automatic confirmation. #[derive(Args, Copy, Clone, Debug)] pub struct Confirmation { diff --git a/sqlx-cli/src/prepare.rs b/sqlx-cli/src/prepare.rs index 7e35844a32..18987048f9 100644 --- a/sqlx-cli/src/prepare.rs +++ b/sqlx-cli/src/prepare.rs @@ -7,7 +7,6 @@ use std::process::Command; use anyhow::{bail, Context}; use console::style; - use sqlx::Connection; use crate::metadata::{manifest_dir, Metadata}; @@ -64,7 +63,9 @@ hint: This command only works in the manifest directory of a Cargo package or wo } async fn prepare(ctx: &PrepareCtx) -> anyhow::Result<()> { - check_backend(&ctx.connect_opts).await?; + if ctx.connect_opts.database_url.is_some() { + check_backend(&ctx.connect_opts).await?; + } let prepare_dir = ctx.prepare_dir()?; run_prepare_step(ctx, &prepare_dir)?; @@ -90,7 +91,9 @@ async fn prepare(ctx: &PrepareCtx) -> anyhow::Result<()> { } async fn prepare_check(ctx: &PrepareCtx) -> anyhow::Result<()> { - let _ = check_backend(&ctx.connect_opts).await?; + if ctx.connect_opts.database_url.is_some() { + check_backend(&ctx.connect_opts).await?; + } // Re-generate and store the queries in a separate directory from both the prepared // queries and the ones generated by `cargo check`, to avoid conflicts. @@ -171,10 +174,14 @@ fn run_prepare_step(ctx: &PrepareCtx, cache_dir: &Path) -> anyhow::Result<()> { check_command .arg("check") .args(&ctx.cargo_args) - .env("DATABASE_URL", &ctx.connect_opts.database_url) + .env("SQLX_TMP", tmp_dir) .env("SQLX_OFFLINE", "false") .env("SQLX_OFFLINE_DIR", cache_dir); + if let Some(database_url) = &ctx.connect_opts.database_url { + check_command.env("DATABASE_URL", database_url); + } + // `cargo check` recompiles on changed rust flags which can be set either via the env var // or through the `rustflags` field in `$CARGO_HOME/config` when the env var isn't set. // Because of this we only pass in `$RUSTFLAGS` when present. @@ -319,12 +326,6 @@ fn minimal_project_recompile_action(metadata: &Metadata) -> ProjectRecompileActi } } -/// Ensure the database server is available. -async fn check_backend(opts: &ConnectOpts) -> anyhow::Result<()> { - crate::connect(opts).await?.close().await?; - Ok(()) -} - /// Find all `query-*.json` files in a directory. fn glob_query_files(path: impl AsRef) -> anyhow::Result> { let path = path.as_ref(); @@ -347,6 +348,11 @@ fn load_json_file(path: impl AsRef) -> anyhow::Result { Ok(serde_json::from_slice(&file_bytes)?) } +async fn check_backend(opts: &ConnectOpts) -> anyhow::Result<()> { + crate::connect(opts).await?.close().await?; + Ok(()) +} + #[cfg(test)] mod tests { use super::*;