From adf89267e44874022fc20ab7f8038d131d7578b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oddbj=C3=B8rn=20Gr=C3=B8dem?= <29732646+oddgrd@users.noreply.github.com> Date: Fri, 10 Mar 2023 16:57:14 +0000 Subject: [PATCH] ci: resolve CI errors in shuttle-next (#580) * test: compile wasm module in axum runtime test setup * ci: add next patch override to CI * ci: include wasm32-wasi target in rust install * fix: deployer tests where runtime fails to start * fix: incorrect provisioner address * feat: log service state changes in runtime * feat: don't send stop req on startup failure * refactor: unused imports * refactor: handling legacy panics * tests: deadlock less * refactor: fixups * refactor: clippy suggestions * tests: mock provisioner * refactor: restore capture from 'log' and colors * refactor: clippy suggestions * tests: longer wait * tests: don't panic while holding lock * tests: don't panic on stream closed * tests: don't filter out state logs * tests: bigger timeout * ci: remove duplicate patch * refactor: comments --------- Co-authored-by: chesedo --- admin/Cargo.toml | 4 +- cargo-shuttle/Cargo.toml | 4 +- cargo-shuttle/src/lib.rs | 2 +- deployer/Cargo.toml | 4 +- deployer/src/deployment/deploy_layer.rs | 444 ++++++++++++------------ deployer/src/deployment/mod.rs | 18 +- deployer/src/deployment/queue.rs | 2 +- deployer/src/deployment/run.rs | 282 +++++++++------ deployer/src/error.rs | 2 + deployer/src/runtime_manager.rs | 53 ++- gateway/Cargo.toml | 4 +- proto/runtime.proto | 24 ++ provisioner/Cargo.toml | 4 +- runtime/Cargo.toml | 4 +- runtime/src/legacy/mod.rs | 65 +++- runtime/src/next/mod.rs | 17 +- 16 files changed, 549 insertions(+), 384 deletions(-) diff --git a/admin/Cargo.toml b/admin/Cargo.toml index 00bcefa0e..d2fda7879 100644 --- a/admin/Cargo.toml +++ b/admin/Cargo.toml @@ -12,8 +12,8 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { version = "1.22.0", features = ["macros", "rt-multi-thread"] } toml = "0.5.9" -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = ["env-filter"] } +tracing = { workspace = true, features = ["default"] } +tracing-subscriber = { workspace = true, features = ["default", "env-filter"] } [dependencies.shuttle-common] workspace = true diff --git a/cargo-shuttle/Cargo.toml b/cargo-shuttle/Cargo.toml index 475c5d900..ad2036b91 100644 --- a/cargo-shuttle/Cargo.toml +++ b/cargo-shuttle/Cargo.toml @@ -45,8 +45,8 @@ tokio-tungstenite = { version = "0.17.2", features = ["native-tls"] } toml = "0.5.9" toml_edit = "0.15.0" tonic = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } +tracing = { workspace = true, features = ["default"] } +tracing-subscriber = { workspace = true, features = ["default", "env-filter", "fmt"] } url = "2.3.1" uuid = { workspace = true, features = ["v4"] } webbrowser = "0.8.2" diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index 2c2466269..d487d72ab 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -510,7 +510,7 @@ impl Shuttle { .into_inner(); tokio::spawn(async move { - while let Some(log) = stream.message().await.expect("to get log from stream") { + while let Ok(Some(log)) = stream.message().await { let log: shuttle_common::LogItem = log.try_into().expect("to convert log"); println!("{log}"); } diff --git a/deployer/Cargo.toml b/deployer/Cargo.toml index 73edc850f..ef0527920 100644 --- a/deployer/Cargo.toml +++ b/deployer/Cargo.toml 
@@ -40,9 +40,9 @@ toml = "0.5.9" tonic = { workspace = true } tower = { version = "0.4.13", features = ["make"] } tower-http = { version = "0.3.4", features = ["auth", "trace"] } -tracing = { workspace = true } +tracing = { workspace = true, features = ["default"] } tracing-opentelemetry = "0.18.0" -tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } +tracing-subscriber = { workspace = true, features = ["default", "env-filter", "fmt"] } uuid = { workspace = true, features = ["v4"] } [dependencies.shuttle-common] diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs index f7d35873b..94e8be3f0 100644 --- a/deployer/src/deployment/deploy_layer.rs +++ b/deployer/src/deployment/deploy_layer.rs @@ -322,19 +322,26 @@ impl Visit for NewStateVisitor { mod tests { use std::{ fs::read_dir, - net::SocketAddr, + net::{Ipv4Addr, SocketAddr}, path::PathBuf, sync::{Arc, Mutex}, time::Duration, }; use crate::{persistence::DeploymentUpdater, RuntimeManager}; + use async_trait::async_trait; use axum::body::Bytes; use ctor::ctor; use flate2::{write::GzEncoder, Compression}; + use portpicker::pick_unused_port; + use shuttle_proto::provisioner::{ + provisioner_server::{Provisioner, ProvisionerServer}, + DatabaseRequest, DatabaseResponse, + }; use tempdir::TempDir; use tokio::{select, time::sleep}; - use tracing_subscriber::prelude::*; + use tonic::transport::Server; + use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use uuid::Uuid; use crate::{ @@ -350,8 +357,46 @@ mod tests { #[ctor] static RECORDER: Arc> = { let recorder = RecorderMock::new(); + + // Copied from the test-log crate + let event_filter = { + use ::tracing_subscriber::fmt::format::FmtSpan; + + match ::std::env::var("RUST_LOG_SPAN_EVENTS") { + Ok(value) => { + value + .to_ascii_lowercase() + .split(',') + .map(|filter| match filter.trim() { + "new" => FmtSpan::NEW, + "enter" => FmtSpan::ENTER, + "exit" => FmtSpan::EXIT, + "close" => FmtSpan::CLOSE, + "active" => FmtSpan::ACTIVE, + "full" => FmtSpan::FULL, + _ => panic!("test-log: RUST_LOG_SPAN_EVENTS must contain filters separated by `,`.\n\t\ + For example: `active` or `new,close`\n\t\ + Supported filters: new, enter, exit, close, active, full\n\t\ + Got: {}", value), + }) + .fold(FmtSpan::NONE, |acc, filter| filter | acc) + }, + Err(::std::env::VarError::NotUnicode(_)) => + panic!("test-log: RUST_LOG_SPAN_EVENTS must contain a valid UTF-8 string"), + Err(::std::env::VarError::NotPresent) => FmtSpan::NONE, + } + }; + let fmt_layer = fmt::layer() + .with_test_writer() + .with_span_events(event_filter); + let filter_layer = EnvFilter::try_from_default_env() + .or_else(|_| EnvFilter::try_new("shuttle_deployer")) + .unwrap(); + tracing_subscriber::registry() .with(DeployLayer::new(Arc::clone(&recorder))) + .with(filter_layer) + .with(fmt_layer) .init(); recorder @@ -404,12 +449,36 @@ mod tests { } } + struct ProvisionerMock; + + #[async_trait] + impl Provisioner for ProvisionerMock { + async fn provision_database( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + panic!("no deploy layer tests should request a db"); + } + } + fn get_runtime_manager() -> Arc> { + let provisioner_addr = + SocketAddr::new(Ipv4Addr::LOCALHOST.into(), pick_unused_port().unwrap()); + let mock = ProvisionerMock; + + tokio::spawn(async move { + Server::builder() + .add_service(ProvisionerServer::new(mock)) + .serve(provisioner_addr) + .await + .unwrap(); + }); + let tmp_dir = TempDir::new("shuttle_run_test").unwrap(); let path = 
tmp_dir.into_path(); let (tx, _rx) = crossbeam_channel::unbounded(); - RuntimeManager::new(path, "http://provisioner:8000".to_string(), tx) + RuntimeManager::new(path, format!("http://{}", provisioner_addr), tx) } #[async_trait::async_trait] @@ -495,6 +564,18 @@ mod tests { } } + async fn test_states(id: &Uuid, expected_states: Vec) { + loop { + let states = RECORDER.lock().unwrap().get_deployment_states(id); + + if *states == expected_states { + break; + } + + sleep(Duration::from_millis(250)).await; + } + } + #[tokio::test(flavor = "multi_thread")] async fn deployment_to_be_queued() { let deployment_manager = get_deployment_manager().await; @@ -503,67 +584,46 @@ mod tests { let id = queued.id; deployment_manager.queue_push(queued).await; - let test = async { - loop { - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); - - if states.len() < 5 { - drop(recorder); // Don't block - sleep(Duration::from_millis(350)).await; - continue; - } - - assert_eq!( - states.len(), - 5, - "did not expect these states:\n\t{states:#?}" - ); - - assert_eq!( - *states, - vec![ - StateLog { - id, - state: State::Queued, - }, - StateLog { - id, - state: State::Building, - }, - StateLog { - id, - state: State::Built, - }, - StateLog { - id, - state: State::Loading, - }, - StateLog { - id, - state: State::Running, - }, - ] - ); - - break; - } - }; + let test = test_states( + &id, + vec![ + StateLog { + id, + state: State::Queued, + }, + StateLog { + id, + state: State::Building, + }, + StateLog { + id, + state: State::Built, + }, + StateLog { + id, + state: State::Loading, + }, + StateLog { + id, + state: State::Running, + }, + ], + ); select! { - _ = sleep(Duration::from_secs(180)) => { - panic!("states should go into 'Running' for a valid service"); - } + _ = sleep(Duration::from_secs(240)) => { + let states = RECORDER.lock().unwrap().get_deployment_states(&id); + panic!("states should go into 'Running' for a valid service: {:#?}", states); + }, _ = test => {} - } + }; // Send kill signal deployment_manager.kill(id).await; sleep(Duration::from_secs(1)).await; - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); + let states = RECORDER.lock().unwrap().get_deployment_states(&id); assert_eq!( *states, @@ -604,60 +664,40 @@ mod tests { let id = queued.id; deployment_manager.queue_push(queued).await; - let test = async { - loop { - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); - - if states.len() < 6 { - drop(recorder); // Don't block - sleep(Duration::from_millis(350)).await; - continue; - } - - assert_eq!( - states.len(), - 6, - "did not expect these states:\n\t{states:#?}" - ); - - assert_eq!( - *states, - vec![ - StateLog { - id, - state: State::Queued, - }, - StateLog { - id, - state: State::Building, - }, - StateLog { - id, - state: State::Built, - }, - StateLog { - id, - state: State::Loading, - }, - StateLog { - id, - state: State::Running, - }, - StateLog { - id, - state: State::Completed, - }, - ] - ); - - break; - } - }; + let test = test_states( + &id, + vec![ + StateLog { + id, + state: State::Queued, + }, + StateLog { + id, + state: State::Building, + }, + StateLog { + id, + state: State::Built, + }, + StateLog { + id, + state: State::Loading, + }, + StateLog { + id, + state: State::Running, + }, + StateLog { + id, + state: State::Completed, + }, + ], + ); select! 
{ - _ = sleep(Duration::from_secs(180)) => { - panic!("states should go into 'Completed' when a service stops by itself"); + _ = sleep(Duration::from_secs(240)) => { + let states = RECORDER.lock().unwrap().get_deployment_states(&id); + panic!("states should go into 'Completed' when a service stops by itself: {:#?}", states); } _ = test => {} } @@ -671,60 +711,40 @@ mod tests { let id = queued.id; deployment_manager.queue_push(queued).await; - let test = async { - loop { - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); - - if states.len() < 6 { - drop(recorder); // Don't block - sleep(Duration::from_millis(350)).await; - continue; - } - - assert_eq!( - states.len(), - 6, - "did not expect these states:\n\t{states:#?}" - ); - - assert_eq!( - *states, - vec![ - StateLog { - id, - state: State::Queued, - }, - StateLog { - id, - state: State::Building, - }, - StateLog { - id, - state: State::Built, - }, - StateLog { - id, - state: State::Loading, - }, - StateLog { - id, - state: State::Running, - }, - StateLog { - id, - state: State::Crashed, - }, - ] - ); - - break; - } - }; + let test = test_states( + &id, + vec![ + StateLog { + id, + state: State::Queued, + }, + StateLog { + id, + state: State::Building, + }, + StateLog { + id, + state: State::Built, + }, + StateLog { + id, + state: State::Loading, + }, + StateLog { + id, + state: State::Running, + }, + StateLog { + id, + state: State::Crashed, + }, + ], + ); select! { - _ = sleep(Duration::from_secs(180)) => { - panic!("states should go into 'Crashed' panicing in bind"); + _ = sleep(Duration::from_secs(240)) => { + let states = RECORDER.lock().unwrap().get_deployment_states(&id); + panic!("states should go into 'Crashed' panicing in bind: {:#?}", states); } _ = test => {} } @@ -738,56 +758,40 @@ mod tests { let id = queued.id; deployment_manager.queue_push(queued).await; - let test = async { - loop { - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); - - if states.len() < 5 { - drop(recorder); // Don't block - sleep(Duration::from_millis(350)).await; - continue; - } - - assert_eq!( - states.len(), - 5, - "did not expect these states:\n\t{states:#?}" - ); - - assert_eq!( - *states, - vec![ - StateLog { - id, - state: State::Queued, - }, - StateLog { - id, - state: State::Building, - }, - StateLog { - id, - state: State::Built, - }, - StateLog { - id, - state: State::Loading, - }, - StateLog { - id, - state: State::Crashed, - }, - ] - ); - - break; - } - }; + let test = test_states( + &id, + vec![ + StateLog { + id, + state: State::Queued, + }, + StateLog { + id, + state: State::Building, + }, + StateLog { + id, + state: State::Built, + }, + StateLog { + id, + state: State::Loading, + }, + StateLog { + id, + state: State::Running, + }, + StateLog { + id, + state: State::Crashed, + }, + ], + ); select! 
{ - _ = sleep(Duration::from_secs(180)) => { - panic!("states should go into 'Crashed' when panicing in main"); + _ = sleep(Duration::from_secs(240)) => { + let states = RECORDER.lock().unwrap().get_deployment_states(&id); + panic!("states should go into 'Crashed' when panicing in main: {:#?}", states); } _ = test => {} } @@ -808,20 +812,8 @@ mod tests { }) .await; - // Give it a small time to start up - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - let recorder = RECORDER.lock().unwrap(); - let states = recorder.get_deployment_states(&id); - - assert_eq!( - states.len(), - 3, - "did not expect these states:\n\t{states:#?}" - ); - - assert_eq!( - *states, + let test = test_states( + &id, vec![ StateLog { id, @@ -835,8 +827,16 @@ mod tests { id, state: State::Crashed, }, - ] + ], ); + + select! { + _ = sleep(Duration::from_secs(50)) => { + let states = RECORDER.lock().unwrap().get_deployment_states(&id); + panic!("from running should start in built and end in crash for invalid: {:#?}", states) + }, + _ = test => {} + }; } #[tokio::test] diff --git a/deployer/src/deployment/mod.rs b/deployer/src/deployment/mod.rs index 121a4c926..09ceeaba2 100644 --- a/deployer/src/deployment/mod.rs +++ b/deployer/src/deployment/mod.rs @@ -15,14 +15,13 @@ use crate::{ persistence::{DeploymentUpdater, SecretGetter, SecretRecorder, State}, RuntimeManager, }; -use tokio::sync::{broadcast, mpsc, Mutex}; +use tokio::sync::{mpsc, Mutex}; use uuid::Uuid; use self::{deploy_layer::LogRecorder, gateway_client::BuildQueueClient}; const QUEUE_BUFFER_SIZE: usize = 100; const RUN_BUFFER_SIZE: usize = 100; -const KILL_BUFFER_SIZE: usize = 10; pub struct DeploymentManagerBuilder { build_log_recorder: Option, @@ -114,7 +113,6 @@ where let (queue_send, queue_recv) = mpsc::channel(QUEUE_BUFFER_SIZE); let (run_send, run_recv) = mpsc::channel(RUN_BUFFER_SIZE); - let (kill_send, _) = broadcast::channel(KILL_BUFFER_SIZE); let storage_manager = ArtifactsStorageManager::new(artifacts_path); let run_send_clone = run_send.clone(); @@ -130,9 +128,8 @@ where )); tokio::spawn(run::task( run_recv, - runtime_manager, + runtime_manager.clone(), deployment_updater, - kill_send.clone(), active_deployment_getter, secret_getter, storage_manager.clone(), @@ -141,7 +138,7 @@ where DeploymentManager { queue_send, run_send, - kill_send, + runtime_manager, storage_manager, } } @@ -151,7 +148,7 @@ where pub struct DeploymentManager { queue_send: QueueSender, run_send: RunSender, - kill_send: KillSender, + runtime_manager: Arc>, storage_manager: ArtifactsStorageManager, } @@ -201,9 +198,7 @@ impl DeploymentManager { } pub async fn kill(&self, id: Uuid) { - if self.kill_send.receiver_count() > 0 { - self.kill_send.send(id).unwrap(); - } + self.runtime_manager.lock().await.kill(&id).await; } pub fn storage_manager(&self) -> ArtifactsStorageManager { @@ -216,6 +211,3 @@ type QueueReceiver = mpsc::Receiver; type RunSender = mpsc::Sender; type RunReceiver = mpsc::Receiver; - -type KillSender = broadcast::Sender; -type KillReceiver = broadcast::Receiver; diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index 0d454b4bb..6bf0675f2 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -128,7 +128,7 @@ async fn remove_from_queue(queue_client: impl BuildQueueClient, id: Uuid) { } } -#[instrument(fields(id = %built.id, state = %State::Built))] +#[instrument(skip(run_send), fields(id = %built.id, state = %State::Built))] async fn promote_to_run(mut built: Built, run_send: 
RunSender) { let cx = Span::current().context(); diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index c1123005a..e536ab0ea 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -10,16 +10,17 @@ use opentelemetry::global; use portpicker::pick_unused_port; use shuttle_common::storage_manager::ArtifactsStorageManager; use shuttle_proto::runtime::{ - runtime_client::RuntimeClient, LoadRequest, StartRequest, StopRequest, StopResponse, + runtime_client::RuntimeClient, LoadRequest, StartRequest, StopReason, SubscribeStopRequest, + SubscribeStopResponse, }; use tokio::sync::Mutex; -use tonic::{transport::Channel, Response, Status}; -use tracing::{debug, debug_span, error, info, instrument, trace, Instrument}; +use tonic::{transport::Channel, Code}; +use tracing::{debug, debug_span, error, info, instrument, trace, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; use uuid::Uuid; -use super::{KillReceiver, KillSender, RunReceiver, State}; +use super::{RunReceiver, State}; use crate::{ error::{Error, Result}, persistence::{DeploymentUpdater, SecretGetter}, @@ -32,7 +33,6 @@ pub async fn task( mut recv: RunReceiver, runtime_manager: Arc>, deployment_updater: impl DeploymentUpdater, - kill_send: KillSender, active_deployment_getter: impl ActiveDeploymentsGetter, secret_getter: impl SecretGetter, storage_manager: ArtifactsStorageManager, @@ -45,8 +45,6 @@ pub async fn task( info!("Built deployment at the front of run queue: {id}"); let deployment_updater = deployment_updater.clone(); - let kill_send = kill_send.clone(); - let kill_recv = kill_send.subscribe(); let secret_getter = secret_getter.clone(); let storage_manager = storage_manager.clone(); @@ -54,14 +52,17 @@ pub async fn task( built.service_id, id, active_deployment_getter.clone(), - kill_send, + runtime_manager.clone(), ); - let cleanup = move |result: std::result::Result, Status>| { - info!(response = ?result, "stop client response: "); - - match result { - Ok(_) => completed_cleanup(&id), - Err(err) => crashed_cleanup(&id, err), + let cleanup = move |response: SubscribeStopResponse| { + debug!(response = ?response, "stop client response: "); + + match StopReason::from_i32(response.reason).unwrap_or_default() { + StopReason::Request => stopped_cleanup(&id), + StopReason::End => completed_cleanup(&id), + StopReason::Crash => { + crashed_cleanup(&id, Error::Run(anyhow::Error::msg(response.message).into())) + } } }; let runtime_manager = runtime_manager.clone(); @@ -80,7 +81,6 @@ pub async fn task( secret_getter, runtime_manager, deployment_updater, - kill_recv, old_deployments_killer, cleanup, ) @@ -97,13 +97,15 @@ pub async fn task( } } -#[instrument(skip(active_deployment_getter, kill_send))] +#[instrument(skip(active_deployment_getter, runtime_manager))] async fn kill_old_deployments( service_id: Uuid, deployment_id: Uuid, active_deployment_getter: impl ActiveDeploymentsGetter, - kill_send: KillSender, + runtime_manager: Arc>, ) -> Result<()> { + let mut guard = runtime_manager.lock().await; + for old_id in active_deployment_getter .clone() .get_active_deployments(&service_id) @@ -113,9 +115,10 @@ async fn kill_old_deployments( .filter(|old_id| old_id != &deployment_id) { trace!(%old_id, "stopping old deployment"); - kill_send - .send(old_id) - .map_err(|e| Error::OldCleanup(Box::new(e)))?; + + if !guard.kill(&old_id).await { + warn!(id = %old_id, "failed to kill old deployment"); + } } Ok(()) @@ -167,7 +170,7 @@ pub struct Built { } impl Built { - 
#[instrument(skip(self, storage_manager, secret_getter, runtime_manager, deployment_updater, kill_recv, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] + #[instrument(skip(self, storage_manager, secret_getter, runtime_manager, deployment_updater, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] #[allow(clippy::too_many_arguments)] async fn handle( self, @@ -175,9 +178,8 @@ impl Built { secret_getter: impl SecretGetter, runtime_manager: Arc>, deployment_updater: impl DeploymentUpdater, - kill_recv: KillReceiver, kill_old_deployments: impl futures::Future>, - cleanup: impl FnOnce(std::result::Result, Status>) + Send + 'static, + cleanup: impl FnOnce(SubscribeStopResponse) + Send + 'static, ) -> Result<()> { let so_path = storage_manager.deployment_library_path(&self.id)?; @@ -191,8 +193,9 @@ impl Built { }; let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); - let mut runtime_manager = runtime_manager.lock().await.clone(); let runtime_client = runtime_manager + .lock() + .await .get_runtime_client(self.is_next) .await .map_err(Error::Runtime)?; @@ -206,26 +209,18 @@ impl Built { self.service_id, so_path, secret_getter, - runtime_client, + runtime_client.clone(), ) .await?; - // Move runtime manager to this thread so that the runtime lives long enough - tokio::spawn(async move { - let runtime_client = runtime_manager - .get_runtime_client(self.is_next) - .await - .unwrap(); - run( - self.id, - runtime_client, - address, - deployment_updater, - kill_recv, - cleanup, - ) - .await - }); + tokio::spawn(run( + self.id, + self.service_name, + runtime_client, + address, + deployment_updater, + cleanup, + )); Ok(()) } @@ -236,7 +231,7 @@ async fn load( service_id: Uuid, so_path: PathBuf, secret_getter: impl SecretGetter, - runtime_client: &mut RuntimeClient, + mut runtime_client: RuntimeClient, ) -> Result<()> { info!( "loading project from: {}", @@ -276,14 +271,14 @@ async fn load( } } -#[instrument(skip(runtime_client, deployment_updater, kill_recv, cleanup), fields(state = %State::Running))] +#[instrument(skip(runtime_client, deployment_updater, cleanup), fields(state = %State::Running))] async fn run( id: Uuid, - runtime_client: &mut RuntimeClient, + service_name: String, + mut runtime_client: RuntimeClient, address: SocketAddr, deployment_updater: impl DeploymentUpdater, - mut kill_recv: KillReceiver, - cleanup: impl FnOnce(std::result::Result, Status>) + Send + 'static, + cleanup: impl FnOnce(SubscribeStopResponse) + Send + 'static, ) { deployment_updater .set_address(&id, &address) @@ -295,41 +290,68 @@ async fn run( ip: address.to_string(), }); - info!("starting service"); - let response = runtime_client - .start(start_request) + // Subscribe to stop before starting to catch immediate errors + let mut stream = runtime_client + .subscribe_stop(tonic::Request::new(SubscribeStopRequest {})) .await - .expect("to start deployment"); + .unwrap() + .into_inner(); - info!(response = ?response.into_inner(), "start client response: "); + info!("starting service"); + let response = runtime_client.start(start_request).await; - let mut response = Err(Status::unknown("not stopped yet")); + match response { + Ok(response) => { + info!(response = ?response.into_inner(), "start client response: "); - while let Ok(kill_id) = kill_recv.recv().await { - if kill_id == id { - let stop_request = tonic::Request::new(StopRequest {}); - response = runtime_client.stop(stop_request).await; + // Wait for stop reason + let reason = 
stream.message().await.unwrap().unwrap(); - break; + cleanup(reason); } - } + Err(ref status) if status.code() == Code::InvalidArgument => { + cleanup(SubscribeStopResponse { + reason: StopReason::Crash as i32, + message: status.to_string(), + }); + } + Err(ref status) => { + start_crashed_cleanup( + &id, + Error::Start("runtime failed to start deployment".to_string()), + ); - cleanup(response); + error!(%status, "failed to start service"); + } + } } #[cfg(test)] mod tests { - use std::{net::SocketAddr, path::PathBuf, process::Command, sync::Arc, time::Duration}; + use std::{ + net::{Ipv4Addr, SocketAddr}, + path::PathBuf, + process::Command, + sync::Arc, + time::Duration, + }; use async_trait::async_trait; + use portpicker::pick_unused_port; use shuttle_common::storage_manager::ArtifactsStorageManager; - use shuttle_proto::runtime::StopResponse; + use shuttle_proto::{ + provisioner::{ + provisioner_server::{Provisioner, ProvisionerServer}, + DatabaseRequest, DatabaseResponse, + }, + runtime::{StopReason, SubscribeStopResponse}, + }; use tempdir::TempDir; use tokio::{ - sync::{broadcast, oneshot, Mutex}, + sync::{oneshot, Mutex}, time::sleep, }; - use tonic::{Response, Status}; + use tonic::transport::Server; use uuid::Uuid; use crate::{ @@ -353,12 +375,36 @@ mod tests { Ok(()) } + struct ProvisionerMock; + + #[async_trait] + impl Provisioner for ProvisionerMock { + async fn provision_database( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + panic!("no run tests should request a db"); + } + } + fn get_runtime_manager() -> Arc> { + let provisioner_addr = + SocketAddr::new(Ipv4Addr::LOCALHOST.into(), pick_unused_port().unwrap()); + let mock = ProvisionerMock; + + tokio::spawn(async move { + Server::builder() + .add_service(ProvisionerServer::new(mock)) + .serve(provisioner_addr) + .await + .unwrap(); + }); + let tmp_dir = TempDir::new("shuttle_run_test").unwrap(); let path = tmp_dir.into_path(); let (tx, _rx) = crossbeam_channel::unbounded(); - RuntimeManager::new(path, "http://localhost:5000".to_string(), tx) + RuntimeManager::new(path, format!("http://{}", provisioner_addr), tx) } #[derive(Clone)] @@ -398,25 +444,25 @@ mod tests { async fn can_be_killed() { let (built, storage_manager) = make_so_and_built("sleep-async"); let id = built.id; - let (kill_send, kill_recv) = broadcast::channel(1); + let runtime_manager = get_runtime_manager(); let (cleanup_send, cleanup_recv) = oneshot::channel(); - let handle_cleanup = |result: std::result::Result, Status>| { - assert!( - result.unwrap().into_inner().success, - "handle should have been cancelled", - ); - cleanup_send.send(()).unwrap(); + let handle_cleanup = |response: SubscribeStopResponse| match ( + StopReason::from_i32(response.reason).unwrap(), + response.message, + ) { + (StopReason::Request, mes) if mes.is_empty() => cleanup_send.send(()).unwrap(), + _ => panic!("expected stop due to request"), }; + let secret_getter = get_secret_getter(); built .handle( storage_manager, secret_getter, - get_runtime_manager(), + runtime_manager.clone(), StubDeploymentUpdater, - kill_recv, kill_old_deployments(), handle_cleanup, ) @@ -427,7 +473,7 @@ mod tests { sleep(Duration::from_secs(1)).await; // Send kill signal - kill_send.send(id).unwrap(); + assert!(runtime_manager.lock().await.kill(&id).await); tokio::select! 
{ _ = sleep(Duration::from_secs(1)) => panic!("cleanup should have been called"), @@ -439,27 +485,25 @@ mod tests { #[tokio::test] async fn self_stop() { let (built, storage_manager) = make_so_and_built("sleep-async"); - let (_kill_send, kill_recv) = broadcast::channel(1); + let runtime_manager = get_runtime_manager(); let (cleanup_send, cleanup_recv) = oneshot::channel(); - let handle_cleanup = |_result: std::result::Result, Status>| { - // let result = result.unwrap(); - // assert!( - // result.is_ok(), - // "did not expect error from self stopping service: {}", - // result.unwrap_err() - // ); - cleanup_send.send(()).unwrap(); + let handle_cleanup = |response: SubscribeStopResponse| match ( + StopReason::from_i32(response.reason).unwrap(), + response.message, + ) { + (StopReason::End, mes) if mes.is_empty() => cleanup_send.send(()).unwrap(), + _ => panic!("expected stop due to self end"), }; + let secret_getter = get_secret_getter(); built .handle( storage_manager, secret_getter, - get_runtime_manager(), + runtime_manager.clone(), StubDeploymentUpdater, - kill_recv, kill_old_deployments(), handle_cleanup, ) @@ -470,33 +514,36 @@ mod tests { _ = sleep(Duration::from_secs(5)) => panic!("cleanup should have been called as service stopped on its own"), Ok(()) = cleanup_recv => {}, } + + // Prevent the runtime manager from dropping earlier, which will kill the processes it manages + drop(runtime_manager); } // Test for panics in Service::bind #[tokio::test] async fn panic_in_bind() { let (built, storage_manager) = make_so_and_built("bind-panic"); - let (_kill_send, kill_recv) = broadcast::channel(1); - let (cleanup_send, cleanup_recv): (oneshot::Sender<()>, _) = oneshot::channel(); - - let handle_cleanup = |_result: std::result::Result, Status>| { - // let result = result.unwrap(); - // assert!( - // matches!(result, Err(shuttle_service::Error::BindPanic(ref msg)) if msg == "panic in bind"), - // "expected inner error from handle: {:?}", - // result - // ); - cleanup_send.send(()).unwrap(); + let runtime_manager = get_runtime_manager(); + let (cleanup_send, cleanup_recv) = oneshot::channel(); + + let handle_cleanup = |response: SubscribeStopResponse| match ( + StopReason::from_i32(response.reason).unwrap(), + response.message, + ) { + (StopReason::Crash, mes) if mes.contains("Panic occurred in `Service::bind`") => { + cleanup_send.send(()).unwrap() + } + (_, mes) => panic!("expected stop due to crash: {mes}"), }; + let secret_getter = get_secret_getter(); built .handle( storage_manager, secret_getter, - get_runtime_manager(), + runtime_manager.clone(), StubDeploymentUpdater, - kill_recv, kill_old_deployments(), handle_cleanup, ) @@ -507,34 +554,49 @@ mod tests { _ = sleep(Duration::from_secs(5)) => panic!("cleanup should have been called as service handle stopped after panic"), Ok(()) = cleanup_recv => {} } + + // Prevent the runtime manager from dropping earlier, which will kill the processes it manages + drop(runtime_manager); } // Test for panics in the main function #[tokio::test] async fn panic_in_main() { let (built, storage_manager) = make_so_and_built("main-panic"); - let (_kill_send, kill_recv) = broadcast::channel(1); + let runtime_manager = get_runtime_manager(); + let (cleanup_send, cleanup_recv) = oneshot::channel(); + + let handle_cleanup = |response: SubscribeStopResponse| match ( + StopReason::from_i32(response.reason).unwrap(), + response.message, + ) { + (StopReason::Crash, mes) if mes.contains("Panic occurred in shuttle_service::main") => { + cleanup_send.send(()).unwrap() 
+ } + (_, mes) => panic!("expected stop due to crash: {mes}"), + }; - let handle_cleanup = |_result| panic!("the service shouldn't even start"); let secret_getter = get_secret_getter(); - let result = built + built .handle( storage_manager, secret_getter, - get_runtime_manager(), + runtime_manager.clone(), StubDeploymentUpdater, - kill_recv, kill_old_deployments(), handle_cleanup, ) - .await; + .await + .unwrap(); - assert!( - matches!(result, Err(Error::Run(shuttle_service::Error::BuildPanic(ref msg))) if msg == "main panic"), - "expected inner error from main: {:?}", - result - ); + tokio::select! { + _ = sleep(Duration::from_secs(5)) => panic!("cleanup should have been called"), + Ok(()) = cleanup_recv => {} + } + + // Prevent the runtime manager from dropping earlier, which will kill the processes it manages + drop(runtime_manager); } #[tokio::test] @@ -546,7 +608,6 @@ mod tests { tracing_context: Default::default(), is_next: false, }; - let (_kill_send, kill_recv) = broadcast::channel(1); let handle_cleanup = |_result| panic!("no service means no cleanup"); let secret_getter = get_secret_getter(); @@ -558,7 +619,6 @@ mod tests { secret_getter, get_runtime_manager(), StubDeploymentUpdater, - kill_recv, kill_old_deployments(), handle_cleanup, ) diff --git a/deployer/src/error.rs b/deployer/src/error.rs index e81eae92b..e5e407276 100644 --- a/deployer/src/error.rs +++ b/deployer/src/error.rs @@ -32,6 +32,8 @@ pub enum Error { GatewayClient(#[from] gateway_client::Error), #[error("Failed to get runtime: {0}")] Runtime(#[source] anyhow::Error), + #[error("Failed to call start on runtime: {0}")] + Start(String), } #[derive(Error, Debug)] diff --git a/deployer/src/runtime_manager.rs b/deployer/src/runtime_manager.rs index 1693fa0c9..c416556b2 100644 --- a/deployer/src/runtime_manager.rs +++ b/deployer/src/runtime_manager.rs @@ -1,10 +1,13 @@ use std::{convert::TryInto, path::PathBuf, sync::Arc}; use anyhow::Context; -use shuttle_proto::runtime::{self, runtime_client::RuntimeClient, SubscribeLogsRequest}; +use shuttle_proto::runtime::{ + self, runtime_client::RuntimeClient, StopRequest, SubscribeLogsRequest, +}; use tokio::{process, sync::Mutex}; use tonic::transport::Channel; use tracing::{info, instrument, trace}; +use uuid::Uuid; use crate::deployment::deploy_layer; @@ -41,13 +44,12 @@ impl RuntimeManager { pub async fn get_runtime_client( &mut self, is_next: bool, - ) -> anyhow::Result<&mut RuntimeClient> { + ) -> anyhow::Result> { if is_next { Self::get_runtime_client_helper( &mut self.next, &mut self.next_process, is_next, - 6002, self.artifacts_path.clone(), &self.provisioner_address, self.log_sender.clone(), @@ -58,7 +60,6 @@ impl RuntimeManager { &mut self.legacy, &mut self.legacy_process, is_next, - 6001, self.artifacts_path.clone(), &self.provisioner_address, self.log_sender.clone(), @@ -67,22 +68,52 @@ impl RuntimeManager { } } + /// Send a kill / stop signal for a deployment to any runtimes currently running + pub async fn kill(&mut self, id: &Uuid) -> bool { + let success_legacy = if let Some(legacy_client) = &mut self.legacy { + trace!(%id, "sending stop signal to legacy for deployment"); + + let stop_request = tonic::Request::new(StopRequest {}); + let response = legacy_client.stop(stop_request).await.unwrap(); + + response.into_inner().success + } else { + trace!("no legacy client running"); + true + }; + + let success_next = if let Some(next_client) = &mut self.next { + trace!(%id, "sending stop signal to next for deployment"); + + let stop_request = 
tonic::Request::new(StopRequest {}); + let response = next_client.stop(stop_request).await.unwrap(); + + response.into_inner().success + } else { + trace!("no next client running"); + true + }; + + success_legacy && success_next + } + #[instrument(skip(runtime_option, process_option, log_sender))] - async fn get_runtime_client_helper<'a>( - runtime_option: &'a mut Option>, + async fn get_runtime_client_helper( + runtime_option: &mut Option>, process_option: &mut Option>>, is_next: bool, - port: u16, artifacts_path: PathBuf, provisioner_address: &str, log_sender: crossbeam_channel::Sender, - ) -> anyhow::Result<&'a mut RuntimeClient> { + ) -> anyhow::Result> { if let Some(runtime_client) = runtime_option { trace!("returning previous client"); - Ok(runtime_client) + Ok(runtime_client.clone()) } else { trace!("making new client"); + let port = portpicker::pick_unused_port().context("failed to find available port")?; + let get_runtime_executable = || { if cfg!(debug_assertions) { // If we're running deployer natively, install shuttle-runtime using the @@ -135,11 +166,11 @@ impl RuntimeManager { } }); - *runtime_option = Some(runtime_client); + *runtime_option = Some(runtime_client.clone()); *process_option = Some(Arc::new(std::sync::Mutex::new(process))); // Safe to unwrap as it was just set - Ok(runtime_option.as_mut().unwrap()) + Ok(runtime_client) } } } diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 46623e80f..3fde2914a 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -39,9 +39,9 @@ strum = { version = "0.24.1", features = ["derive"] } tokio = { version = "1.22.0", features = [ "full" ] } tower = { version = "0.4.13", features = [ "steer" ] } tower-http = { version = "0.3.4", features = ["trace"] } -tracing = { workspace = true } +tracing = { workspace = true, features = ["default"] } tracing-opentelemetry = "0.18.0" -tracing-subscriber = { workspace = true, features = ["env-filter"] } +tracing-subscriber = { workspace = true, features = ["default", "env-filter"] } ttl_cache = "0.5.1" uuid = { workspace = true, features = [ "v4" ] } diff --git a/proto/runtime.proto b/proto/runtime.proto index 567a3fa58..d75505419 100644 --- a/proto/runtime.proto +++ b/proto/runtime.proto @@ -13,6 +13,9 @@ service Runtime { // Stop a started service rpc Stop(StopRequest) returns (StopResponse); + // Channel to notify a service has been stopped + rpc SubscribeStop(SubscribeStopRequest) returns (stream SubscribeStopResponse); + // Subscribe to runtime logs rpc SubscribeLogs(SubscribeLogsRequest) returns (stream LogItem); } @@ -52,6 +55,27 @@ message StopResponse { bool success = 1; } +message SubscribeStopRequest {} + +message SubscribeStopResponse { + // Reason the service has stopped + StopReason reason = 1; + + // Any extra message to go with the reason. 
If there are any + string message = 2; +} + +enum StopReason { + // User requested this stop + Request = 0; + + // Service stopped by itself + End = 1; + + // Service crashed + Crash = 2; +} + message SubscribeLogsRequest {} message LogItem { diff --git a/provisioner/Cargo.toml b/provisioner/Cargo.toml index 02b1b8f50..89ced2ac6 100644 --- a/provisioner/Cargo.toml +++ b/provisioner/Cargo.toml @@ -19,8 +19,8 @@ sqlx = { version = "0.6.2", features = ["postgres", "runtime-tokio-native-tls"] thiserror = { workspace = true } tokio = { version = "1.22.0", features = ["macros", "rt-multi-thread"] } tonic = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = ["fmt"] } +tracing = { workspace = true, features = ["default"] } +tracing-subscriber = { workspace = true, features = ["default", "fmt"] } [dependencies.shuttle-proto] workspace = true diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 36f83ee99..4fdfe7d73 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -23,8 +23,8 @@ thiserror = { workspace = true } tokio = { version = "1.22.0", features = ["full"] } tokio-stream = "0.1.11" tonic = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } +tracing = { workspace = true, features = ["default"] } +tracing-subscriber = { workspace = true, features = ["default", "env-filter", "fmt"] } uuid = { workspace = true, features = ["v4"] } # TODO: bump these crates to 6.0 when we bump rust to >= 1.66 diff --git a/runtime/src/legacy/mod.rs b/runtime/src/legacy/mod.rs index cad51259c..e87ad02e3 100644 --- a/runtime/src/legacy/mod.rs +++ b/runtime/src/legacy/mod.rs @@ -21,13 +21,16 @@ use shuttle_proto::{ runtime::{ self, runtime_server::{Runtime, RuntimeServer}, - LoadRequest, LoadResponse, StartRequest, StartResponse, StopRequest, StopResponse, - SubscribeLogsRequest, + LoadRequest, LoadResponse, StartRequest, StartResponse, StopReason, StopRequest, + StopResponse, SubscribeLogsRequest, SubscribeStopRequest, SubscribeStopResponse, }, }; use shuttle_service::{Factory, Service, ServiceName}; -use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use tokio::sync::oneshot; +use tokio::sync::{broadcast, oneshot}; +use tokio::sync::{ + broadcast::Sender, + mpsc::{self, UnboundedReceiver, UnboundedSender}, +}; use tokio_stream::wrappers::ReceiverStream; use tonic::{ transport::{Endpoint, Server}, @@ -70,6 +73,7 @@ pub struct Legacy { // Mutexes are for interior mutability logs_rx: Mutex>>, logs_tx: UnboundedSender, + stopped_tx: Sender<(StopReason, String)>, provisioner_address: Endpoint, kill_tx: Mutex>>, storage_manager: M, @@ -80,10 +84,12 @@ pub struct Legacy { impl Legacy { pub fn new(provisioner_address: Endpoint, loader: L, storage_manager: M) -> Self { let (tx, rx) = mpsc::unbounded_channel(); + let (stopped_tx, _stopped_rx) = broadcast::channel(10); Self { logs_rx: Mutex::new(Some(rx)), logs_tx: tx, + stopped_tx, kill_tx: Mutex::new(None), provisioner_address, storage_manager, @@ -196,7 +202,12 @@ where *self.kill_tx.lock().unwrap() = Some(kill_tx); // start service as a background task with a kill receiver - tokio::spawn(run_until_stopped(service, service_address, kill_rx)); + tokio::spawn(run_until_stopped( + service, + service_address, + self.stopped_tx.clone(), + kill_rx, + )); let message = StartResponse { success: true }; @@ -241,26 +252,60 @@ where Err(Status::internal("failed to stop deployment")) } } + + type SubscribeStopStream = 
ReceiverStream>; + + async fn subscribe_stop( + &self, + _request: Request, + ) -> Result, Status> { + let mut stopped_rx = self.stopped_tx.subscribe(); + let (tx, rx) = mpsc::channel(1); + + // Move the stop channel into a stream to be returned + tokio::spawn(async move { + while let Ok((reason, message)) = stopped_rx.recv().await { + tx.send(Ok(SubscribeStopResponse { + reason: reason as i32, + message, + })) + .await + .unwrap(); + } + }); + + Ok(Response::new(ReceiverStream::new(rx))) + } } /// Run the service until a stop signal is received -#[instrument(skip(service, kill_rx))] +#[instrument(skip(service, stopped_tx, kill_rx))] async fn run_until_stopped( // service: LoadedService, service: impl Service, addr: SocketAddr, + stopped_tx: tokio::sync::broadcast::Sender<(StopReason, String)>, kill_rx: tokio::sync::oneshot::Receiver, ) { trace!("starting deployment on {}", &addr); tokio::select! { - _ = service.bind(addr) => { - trace!("deployment stopped on {}", &addr); + res = service.bind(addr) => { + match res { + Ok(_) => { + stopped_tx.send((StopReason::End, String::new())).unwrap(); + } + Err(error) => { + stopped_tx.send((StopReason::Crash, error.to_string())).unwrap(); + } + } }, message = kill_rx => { match message { - Ok(msg) => trace!("{msg}"), + Ok(_) => { + stopped_tx.send((StopReason::Request, String::new())).unwrap(); + } Err(_) => trace!("the sender dropped") - } + }; } } } diff --git a/runtime/src/next/mod.rs b/runtime/src/next/mod.rs index 41f50eaac..e22d0a9f0 100644 --- a/runtime/src/next/mod.rs +++ b/runtime/src/next/mod.rs @@ -18,13 +18,14 @@ use shuttle_common::wasm::{Bytesable, Log, RequestWrapper, ResponseWrapper}; use shuttle_proto::runtime::runtime_server::Runtime; use shuttle_proto::runtime::{ self, LoadRequest, LoadResponse, StartRequest, StartResponse, StopRequest, StopResponse, - SubscribeLogsRequest, + SubscribeLogsRequest, SubscribeStopRequest, SubscribeStopResponse, }; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; use tonic::Status; use tracing::{error, trace}; +use uuid::Uuid; use wasi_common::file::FileCaps; use wasmtime::{Engine, Linker, Module, Store}; use wasmtime_wasi::sync::net::UnixStream as WasiUnixStream; @@ -167,8 +168,19 @@ impl Runtime for AxumWasm { )) } } -} + type SubscribeStopStream = ReceiverStream>; + + async fn subscribe_stop( + &self, + _request: tonic::Request, + ) -> Result, Status> { + // Next does not really have a stopped state. Endpoints are loaded if and when needed until a request is done + let (_tx, rx) = mpsc::channel(1); + + Ok(tonic::Response::new(ReceiverStream::new(rx))) + } +} struct RouterBuilder { engine: Engine, linker: Linker, @@ -421,7 +433,6 @@ pub mod tests { let router = RouterBuilder::new() .unwrap() - .src("axum.wasm") .src("tests/resources/axum-wasm-expanded/target/wasm32-wasi/debug/shuttle_axum_expanded.wasm") .build() .unwrap();