From cc072b2bb392126c075aa80bf03fd482ddda4f6d Mon Sep 17 00:00:00 2001 From: Pieter Date: Mon, 16 Jan 2023 13:29:31 +0200 Subject: [PATCH] feat: deployer next (#575) * feat: propagate next runtime * feat: store is_next in DB * feat: runtime manager to allow deployer to start up both runtimes * feat: make sure tests run * refactor: better migration query * refactor: handle runtime errors better * feat: shutdown runtimes * bug: missing so * bug: stop services * bug: ffi and runtime manager not living long enough * bug: missing so error * refactor: run cleanups * refactor: clippy suggestions --- cargo-shuttle/src/lib.rs | 1 + codegen/src/next/mod.rs | 1 + deployer/migrations/0001_next.sql | 1 + deployer/src/deployment/deploy_layer.rs | 96 +++------ deployer/src/deployment/mod.rs | 43 ++-- deployer/src/deployment/queue.rs | 52 +++-- deployer/src/deployment/run.rs | 248 ++++++++++++++---------- deployer/src/error.rs | 6 +- deployer/src/handlers/mod.rs | 1 + deployer/src/lib.rs | 17 +- deployer/src/main.rs | 36 +--- deployer/src/persistence/deployment.rs | 17 +- deployer/src/persistence/mod.rs | 78 ++++++-- deployer/src/runtime_manager.rs | 127 ++++++++++++ proto/src/lib.rs | 5 +- runtime/src/args.rs | 4 + runtime/src/legacy/mod.rs | 10 + runtime/src/main.rs | 2 +- service/src/loader.rs | 2 + tmp/axum-wasm/Cargo.toml | 4 + tmp/axum-wasm/src/lib.rs | 24 +++ 21 files changed, 525 insertions(+), 250 deletions(-) create mode 100644 deployer/migrations/0001_next.sql create mode 100644 deployer/src/runtime_manager.rs diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs index e6851fb4c..16c7866f9 100644 --- a/cargo-shuttle/src/lib.rs +++ b/cargo-shuttle/src/lib.rs @@ -413,6 +413,7 @@ impl Shuttle { is_wasm, runtime::StorageManagerType::WorkingDir(working_directory.to_path_buf()), &format!("http://localhost:{}", run_args.port + 1), + run_args.port + 2, ) .await .map_err(|err| { diff --git a/codegen/src/next/mod.rs b/codegen/src/next/mod.rs index 3c0f93b3e..5e8ad4e88 100644 --- a/codegen/src/next/mod.rs +++ b/codegen/src/next/mod.rs @@ -259,6 +259,7 @@ pub(crate) fn wasi_bindings(app: App) -> proc_macro2::TokenStream { quote!( #app + #[cfg(not(test))] #[no_mangle] #[allow(non_snake_case)] pub extern "C" fn __SHUTTLE_Axum_call( diff --git a/deployer/migrations/0001_next.sql b/deployer/migrations/0001_next.sql new file mode 100644 index 000000000..42b89c217 --- /dev/null +++ b/deployer/migrations/0001_next.sql @@ -0,0 +1 @@ +ALTER TABLE deployments ADD COLUMN is_next BOOLEAN DEFAULT 0 NOT NULL; diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs index afe795a72..5b56dcba0 100644 --- a/deployer/src/deployment/deploy_layer.rs +++ b/deployer/src/deployment/deploy_layer.rs @@ -23,8 +23,8 @@ use chrono::{DateTime, Utc}; use serde_json::json; use shuttle_common::{tracing::JsonVisitor, STATE_MESSAGE}; use shuttle_proto::runtime; -use std::{net::SocketAddr, str::FromStr, time::SystemTime}; -use tracing::{error, field::Visit, span, warn, Metadata, Subscriber}; +use std::{str::FromStr, time::SystemTime}; +use tracing::{field::Visit, span, warn, Metadata, Subscriber}; use tracing_subscriber::Layer; use uuid::Uuid; @@ -63,8 +63,6 @@ pub struct Log { pub fields: serde_json::Value, pub r#type: LogType, - - pub address: Option, } impl From for persistence::Log { @@ -106,23 +104,10 @@ impl From for shuttle_common::LogItem { impl From for DeploymentState { fn from(log: Log) -> Self { - let address = if let Some(address_str) = log.address { - match SocketAddr::from_str(&address_str) { - Ok(address) => Some(address), - Err(err) => { - error!(error = %err, "failed to convert to [SocketAddr]"); - None - } - } - } else { - None - }; - Self { id: log.id, state: log.state, last_update: log.timestamp, - address, } } } @@ -139,7 +124,6 @@ impl From for Log { target: log.target, fields: serde_json::from_slice(&log.fields).unwrap(), r#type: LogType::Event, - address: None, } } } @@ -230,7 +214,6 @@ where .unwrap_or_else(|| metadata.target().to_string()), fields: serde_json::Value::Object(visitor.fields), r#type: LogType::Event, - address: None, }); break; } @@ -274,7 +257,6 @@ where target: metadata.target().to_string(), fields: Default::default(), r#type: LogType::State, - address: details.address.clone(), }); extensions.insert::(details); @@ -286,7 +268,6 @@ where struct ScopeDetails { id: Uuid, state: State, - address: Option, } impl From<&tracing::Level> for LogLevel { @@ -314,9 +295,6 @@ impl NewStateVisitor { /// Field containing the deployment state identifier const STATE_IDENT: &'static str = "state"; - /// Field containing the deployment address identifier - const ADDRESS_IDENT: &'static str = "address"; - fn is_valid(metadata: &Metadata) -> bool { metadata.is_span() && metadata.fields().field(Self::ID_IDENT).is_some() @@ -330,8 +308,6 @@ impl Visit for NewStateVisitor { self.details.state = State::from_str(&format!("{value:?}")).unwrap_or_default(); } else if field.name() == Self::ID_IDENT { self.details.id = Uuid::try_parse(&format!("{value:?}")).unwrap_or_default(); - } else if field.name() == Self::ADDRESS_IDENT { - self.details.address = Some(format!("{value:?}")); } } } @@ -340,17 +316,18 @@ impl Visit for NewStateVisitor { mod tests { use std::{ fs::read_dir, + net::SocketAddr, path::PathBuf, sync::{Arc, Mutex}, time::Duration, }; + use crate::{persistence::DeploymentUpdater, RuntimeManager}; use axum::body::Bytes; use ctor::ctor; use flate2::{write::GzEncoder, Compression}; - use shuttle_proto::runtime::runtime_client::RuntimeClient; + use tempdir::TempDir; use tokio::{select, time::sleep}; - use tonic::transport::Channel; use tracing_subscriber::prelude::*; use uuid::Uuid; @@ -383,7 +360,6 @@ mod tests { struct StateLog { id: Uuid, state: State, - has_address: bool, } impl From for StateLog { @@ -391,7 +367,6 @@ mod tests { Self { id: log.id, state: log.state, - has_address: log.address.is_some(), } } } @@ -423,10 +398,12 @@ mod tests { } } - async fn get_runtime_client() -> RuntimeClient { - RuntimeClient::connect("http://127.0.0.1:6001") - .await - .unwrap() + fn get_runtime_manager() -> Arc> { + let tmp_dir = TempDir::new("shuttle_run_test").unwrap(); + let path = tmp_dir.into_path(); + let (tx, _rx) = crossbeam_channel::unbounded(); + + RuntimeManager::new(path, "http://provisioner:8000".to_string(), tx) } #[async_trait::async_trait] @@ -449,6 +426,22 @@ mod tests { } } + #[derive(Clone)] + struct StubDeploymentUpdater; + + #[async_trait::async_trait] + impl DeploymentUpdater for StubDeploymentUpdater { + type Err = std::io::Error; + + async fn set_address(&self, _id: &Uuid, _address: &SocketAddr) -> Result<(), Self::Err> { + Ok(()) + } + + async fn set_is_next(&self, _id: &Uuid, _is_next: bool) -> Result<(), Self::Err> { + Ok(()) + } + } + #[derive(Clone)] struct StubActiveDeploymentGetter; @@ -527,27 +520,22 @@ mod tests { StateLog { id, state: State::Queued, - has_address: false, }, StateLog { id, state: State::Building, - has_address: false, }, StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Running, - has_address: true, }, ] ); @@ -577,32 +565,26 @@ mod tests { StateLog { id, state: State::Queued, - has_address: false, }, StateLog { id, state: State::Building, - has_address: false, }, StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Running, - has_address: true, }, StateLog { id, state: State::Stopped, - has_address: false, }, ] ); @@ -639,32 +621,26 @@ mod tests { StateLog { id, state: State::Queued, - has_address: false, }, StateLog { id, state: State::Building, - has_address: false, }, StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Running, - has_address: true, }, StateLog { id, state: State::Completed, - has_address: false, }, ] ); @@ -712,32 +688,26 @@ mod tests { StateLog { id, state: State::Queued, - has_address: false, }, StateLog { id, state: State::Building, - has_address: false, }, StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Running, - has_address: true, }, StateLog { id, state: State::Crashed, - has_address: false, }, ] ); @@ -785,27 +755,22 @@ mod tests { StateLog { id, state: State::Queued, - has_address: false, }, StateLog { id, state: State::Building, - has_address: false, }, StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Crashed, - has_address: false, }, ] ); @@ -833,6 +798,7 @@ mod tests { service_name: "run-test".to_string(), service_id: Uuid::new_v4(), tracing_context: Default::default(), + is_next: false, }) .await; @@ -854,17 +820,14 @@ mod tests { StateLog { id, state: State::Built, - has_address: false, }, StateLog { id, state: State::Loading, - has_address: true, }, StateLog { id, state: State::Crashed, - has_address: false, }, ] ); @@ -905,7 +868,8 @@ mod tests { .active_deployment_getter(StubActiveDeploymentGetter) .artifacts_path(PathBuf::from("/tmp")) .secret_getter(StubSecretGetter) - .runtime(get_runtime_client().await) + .runtime(get_runtime_manager()) + .deployment_updater(StubDeploymentUpdater) .queue_client(StubBuildQueueClient) .build() } diff --git a/deployer/src/deployment/mod.rs b/deployer/src/deployment/mod.rs index 0048573ae..121a4c926 100644 --- a/deployer/src/deployment/mod.rs +++ b/deployer/src/deployment/mod.rs @@ -3,18 +3,19 @@ pub mod gateway_client; mod queue; mod run; -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; pub use queue::Queued; pub use run::{ActiveDeploymentsGetter, Built}; use shuttle_common::storage_manager::ArtifactsStorageManager; -use shuttle_proto::runtime::runtime_client::RuntimeClient; -use tonic::transport::Channel; use tracing::{instrument, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use crate::persistence::{SecretGetter, SecretRecorder, State}; -use tokio::sync::{broadcast, mpsc}; +use crate::{ + persistence::{DeploymentUpdater, SecretGetter, SecretRecorder, State}, + RuntimeManager, +}; +use tokio::sync::{broadcast, mpsc, Mutex}; use uuid::Uuid; use self::{deploy_layer::LogRecorder, gateway_client::BuildQueueClient}; @@ -23,21 +24,23 @@ const QUEUE_BUFFER_SIZE: usize = 100; const RUN_BUFFER_SIZE: usize = 100; const KILL_BUFFER_SIZE: usize = 10; -pub struct DeploymentManagerBuilder { +pub struct DeploymentManagerBuilder { build_log_recorder: Option, secret_recorder: Option, active_deployment_getter: Option, artifacts_path: Option, - runtime_client: Option>, + runtime_manager: Option>>, + deployment_updater: Option, secret_getter: Option, queue_client: Option, } -impl DeploymentManagerBuilder +impl DeploymentManagerBuilder where LR: LogRecorder, SR: SecretRecorder, ADG: ActiveDeploymentsGetter, + DU: DeploymentUpdater, SG: SecretGetter, QC: BuildQueueClient, { @@ -77,8 +80,14 @@ where self } - pub fn runtime(mut self, runtime_client: RuntimeClient) -> Self { - self.runtime_client = Some(runtime_client); + pub fn runtime(mut self, runtime_manager: Arc>) -> Self { + self.runtime_manager = Some(runtime_manager); + + self + } + + pub fn deployment_updater(mut self, deployment_updater: DU) -> Self { + self.deployment_updater = Some(deployment_updater); self } @@ -97,7 +106,10 @@ where .expect("an active deployment getter to be set"); let artifacts_path = self.artifacts_path.expect("artifacts path to be set"); let queue_client = self.queue_client.expect("a queue client to be set"); - let runtime_client = self.runtime_client.expect("a runtime client to be set"); + let runtime_manager = self.runtime_manager.expect("a runtime manager to be set"); + let deployment_updater = self + .deployment_updater + .expect("a deployment updater to be set"); let secret_getter = self.secret_getter.expect("a secret getter to be set"); let (queue_send, queue_recv) = mpsc::channel(QUEUE_BUFFER_SIZE); @@ -110,6 +122,7 @@ where tokio::spawn(queue::task( queue_recv, run_send_clone, + deployment_updater.clone(), build_log_recorder, secret_recorder, storage_manager.clone(), @@ -117,7 +130,8 @@ where )); tokio::spawn(run::task( run_recv, - runtime_client, + runtime_manager, + deployment_updater, kill_send.clone(), active_deployment_getter, secret_getter, @@ -158,13 +172,14 @@ pub struct DeploymentManager { impl DeploymentManager { /// Create a new deployment manager. Manages one or more 'pipelines' for /// processing service building, loading, and deployment. - pub fn builder() -> DeploymentManagerBuilder { + pub fn builder() -> DeploymentManagerBuilder { DeploymentManagerBuilder { build_log_recorder: None, secret_recorder: None, active_deployment_getter: None, artifacts_path: None, - runtime_client: None, + runtime_manager: None, + deployment_updater: None, secret_getter: None, queue_client: None, } diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index c3bcd1e52..9e871986c 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -2,7 +2,7 @@ use super::deploy_layer::{Log, LogRecorder, LogType}; use super::gateway_client::BuildQueueClient; use super::{Built, QueueReceiver, RunSender, State}; use crate::error::{Error, Result, TestError}; -use crate::persistence::{LogLevel, SecretRecorder}; +use crate::persistence::{DeploymentUpdater, LogLevel, SecretRecorder}; use shuttle_common::storage_manager::{ArtifactsStorageManager, StorageManager}; use cargo::util::interning::InternedString; @@ -34,6 +34,7 @@ use tokio::fs; pub async fn task( mut recv: QueueReceiver, run_send: RunSender, + deployment_updater: impl DeploymentUpdater, log_recorder: impl LogRecorder, secret_recorder: impl SecretRecorder, storage_manager: ArtifactsStorageManager, @@ -46,6 +47,7 @@ pub async fn task( info!("Queued deployment at the front of the queue: {id}"); + let deployment_updater = deployment_updater.clone(); let run_send_cloned = run_send.clone(); let log_recorder = log_recorder.clone(); let secret_recorder = secret_recorder.clone(); @@ -71,7 +73,12 @@ pub async fn task( } match queued - .handle(storage_manager, log_recorder, secret_recorder) + .handle( + storage_manager, + deployment_updater, + log_recorder, + secret_recorder, + ) .await { Ok(built) => promote_to_run(built, run_send_cloned).await, @@ -144,10 +151,11 @@ pub struct Queued { } impl Queued { - #[instrument(skip(self, storage_manager, log_recorder, secret_recorder), fields(id = %self.id, state = %State::Building))] + #[instrument(skip(self, storage_manager, deployment_updater, log_recorder, secret_recorder), fields(id = %self.id, state = %State::Building))] async fn handle( self, storage_manager: ArtifactsStorageManager, + deployment_updater: impl DeploymentUpdater, log_recorder: impl LogRecorder, secret_recorder: impl SecretRecorder, ) -> Result { @@ -180,7 +188,6 @@ impl Queued { target: String::new(), fields: json!({ "build_line": line }), r#type: LogType::Event, - address: None, }, message => Log { id, @@ -192,7 +199,6 @@ impl Queued { target: String::new(), fields: serde_json::to_value(message).unwrap(), r#type: LogType::Event, - address: None, }, }; log_recorder.record(log); @@ -200,7 +206,7 @@ impl Queued { }); let project_path = project_path.canonicalize()?; - let so_path = build_deployment(self.id, &project_path, tx.clone()).await?; + let runtime = build_deployment(self.id, &project_path, tx.clone()).await?; if self.will_run_tests { info!( @@ -213,13 +219,21 @@ impl Queued { info!("Moving built library"); - store_lib(&storage_manager, so_path, &self.id).await?; + store_lib(&storage_manager, &runtime, &self.id).await?; + + let is_next = matches!(runtime, Runtime::Next(_)); + + deployment_updater + .set_is_next(&id, is_next) + .await + .map_err(|e| Error::Build(Box::new(e)))?; let built = Built { id: self.id, service_name: self.service_name, service_id: self.service_id, tracing_context: Default::default(), + is_next, }; Ok(built) @@ -310,15 +324,10 @@ async fn build_deployment( deployment_id: Uuid, project_path: &Path, tx: crossbeam_channel::Sender, -) -> Result { - let runtime_path = build_crate(deployment_id, project_path, true, tx) +) -> Result { + build_crate(deployment_id, project_path, true, tx) .await - .map_err(|e| Error::Build(e.into()))?; - - match runtime_path { - Runtime::Legacy(so_path) => Ok(so_path), - Runtime::Next(_) => todo!(), - } + .map_err(|e| Error::Build(e.into())) } #[instrument(skip(project_path, tx))] @@ -381,12 +390,17 @@ async fn run_pre_deploy_tests( } /// Store 'so' file in the libs folder -#[instrument(skip(storage_manager, so_path, id))] +#[instrument(skip(storage_manager, runtime, id))] async fn store_lib( storage_manager: &ArtifactsStorageManager, - so_path: impl AsRef, + runtime: &Runtime, id: &Uuid, ) -> Result<()> { + let so_path = match runtime { + Runtime::Next(path) => path, + Runtime::Legacy(path) => path, + }; + let new_so_path = storage_manager.deployment_library_path(id)?; fs::rename(so_path, new_so_path).await?; @@ -399,6 +413,7 @@ mod tests { use std::{collections::BTreeMap, fs::File, io::Write, path::Path}; use shuttle_common::storage_manager::ArtifactsStorageManager; + use shuttle_service::loader::Runtime; use tempdir::TempDir; use tokio::fs; use uuid::Uuid; @@ -533,11 +548,12 @@ ff0e55bda1ff01000000000000000000e0079c01ff12a55500280000", let build_p = storage_manager.builds_path().unwrap(); let so_path = build_p.join("xyz.so"); + let runtime = Runtime::Legacy(so_path.clone()); let id = Uuid::new_v4(); fs::write(&so_path, "barfoo").await.unwrap(); - super::store_lib(&storage_manager, &so_path, &id) + super::store_lib(&storage_manager, &runtime, &id) .await .unwrap(); diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index 7b14a4e7b..a9edebc1d 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -2,33 +2,36 @@ use std::{ collections::HashMap, net::{Ipv4Addr, SocketAddr}, path::PathBuf, - str::FromStr, + sync::Arc, }; use async_trait::async_trait; use opentelemetry::global; use portpicker::pick_unused_port; -use shuttle_common::project::ProjectName as ServiceName; use shuttle_common::storage_manager::ArtifactsStorageManager; -use shuttle_proto::runtime::{runtime_client::RuntimeClient, LoadRequest, StartRequest}; +use shuttle_proto::runtime::{ + runtime_client::RuntimeClient, LoadRequest, StartRequest, StopRequest, StopResponse, +}; -use tokio::task::JoinError; -use tonic::transport::Channel; -use tracing::{debug_span, error, info, instrument, trace, Instrument}; +use tokio::sync::Mutex; +use tonic::{transport::Channel, Response, Status}; +use tracing::{debug, debug_span, error, info, instrument, trace, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; use uuid::Uuid; use super::{KillReceiver, KillSender, RunReceiver, State}; use crate::{ error::{Error, Result}, - persistence::SecretGetter, + persistence::{DeploymentUpdater, SecretGetter}, + RuntimeManager, }; /// Run a task which takes runnable deploys from a channel and starts them up on our runtime /// A deploy is killed when it receives a signal from the kill channel pub async fn task( mut recv: RunReceiver, - runtime_client: RuntimeClient, + runtime_manager: Arc>, + deployment_updater: impl DeploymentUpdater, kill_send: KillSender, active_deployment_getter: impl ActiveDeploymentsGetter, secret_getter: impl SecretGetter, @@ -41,37 +44,27 @@ pub async fn task( info!("Built deployment at the front of run queue: {id}"); + let deployment_updater = deployment_updater.clone(); let kill_send = kill_send.clone(); let kill_recv = kill_send.subscribe(); let secret_getter = secret_getter.clone(); let storage_manager = storage_manager.clone(); - let _service_name = match ServiceName::from_str(&built.service_name) { - Ok(name) => name, - Err(err) => { - start_crashed_cleanup(&id, err); - continue; - } - }; - let old_deployments_killer = kill_old_deployments( built.service_id, id, active_deployment_getter.clone(), kill_send, ); - let cleanup = move |result: std::result::Result< - std::result::Result<(), shuttle_service::Error>, - JoinError, - >| match result { - Ok(inner) => match inner { - Ok(()) => completed_cleanup(&id), + let cleanup = move |result: std::result::Result, Status>| { + info!(response = ?result, "stop client response: "); + + match result { + Ok(_) => completed_cleanup(&id), Err(err) => crashed_cleanup(&id, err), - }, - Err(err) if err.is_cancelled() => stopped_cleanup(&id), - Err(err) => start_crashed_cleanup(&id, err), + } }; - let runtime_client = runtime_client.clone(); + let runtime_manager = runtime_manager.clone(); tokio::spawn(async move { let parent_cx = global::get_text_map_propagator(|propagator| { @@ -85,7 +78,8 @@ pub async fn task( .handle( storage_manager, secret_getter, - runtime_client, + runtime_manager, + deployment_updater, kill_recv, old_deployments_killer, cleanup, @@ -169,21 +163,21 @@ pub struct Built { pub service_name: String, pub service_id: Uuid, pub tracing_context: HashMap, + pub is_next: bool, } impl Built { - #[instrument(skip(self, storage_manager, secret_getter, runtime_client, kill_recv, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] + #[instrument(skip(self, storage_manager, secret_getter, runtime_manager, deployment_updater, kill_recv, kill_old_deployments, cleanup), fields(id = %self.id, state = %State::Loading))] #[allow(clippy::too_many_arguments)] async fn handle( self, storage_manager: ArtifactsStorageManager, secret_getter: impl SecretGetter, - runtime_client: RuntimeClient, + runtime_manager: Arc>, + deployment_updater: impl DeploymentUpdater, kill_recv: KillReceiver, kill_old_deployments: impl futures::Future>, - cleanup: impl FnOnce(std::result::Result, JoinError>) - + Send - + 'static, + cleanup: impl FnOnce(std::result::Result, Status>) + Send + 'static, ) -> Result<()> { let so_path = storage_manager.deployment_library_path(&self.id)?; @@ -197,6 +191,11 @@ impl Built { }; let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), port); + let mut runtime_manager = runtime_manager.lock().await.clone(); + let runtime_client = runtime_manager + .get_runtime_client(self.is_next) + .await + .map_err(Error::Runtime)?; kill_old_deployments.await?; @@ -207,17 +206,27 @@ impl Built { self.service_id, so_path, secret_getter, - runtime_client.clone(), - ) - .await; - tokio::spawn(run( - self.id, - self.service_name, runtime_client, - address, - kill_recv, - cleanup, - )); + ) + .await?; + + // Move runtime manager to this thread so that the runtime lives long enough + tokio::spawn(async move { + let runtime_client = runtime_manager + .get_runtime_client(self.is_next) + .await + .unwrap(); + run( + self.id, + self.service_name, + runtime_client, + address, + deployment_updater, + kill_recv, + cleanup, + ) + .await + }); Ok(()) } @@ -228,8 +237,8 @@ async fn load( service_id: Uuid, so_path: PathBuf, secret_getter: impl SecretGetter, - mut runtime_client: RuntimeClient, -) { + runtime_client: &mut RuntimeClient, +) -> Result<()> { info!( "loading project from: {}", so_path.clone().into_os_string().into_string().unwrap() @@ -248,56 +257,81 @@ async fn load( service_name: service_name.clone(), secrets, }); - info!("loading service"); + + debug!("loading service"); let response = runtime_client.load(load_request).await; - if let Err(e) = response { - info!("failed to load service: {}", e); + match response { + Ok(response) => { + info!(response = ?response.into_inner(), "loading response: "); + Ok(()) + } + Err(error) => { + error!(%error, "failed to load service"); + Err(Error::Load(error.to_string())) + } } } -#[instrument(skip(runtime_client, _kill_recv, _cleanup), fields(state = %State::Running))] +#[instrument(skip(runtime_client, deployment_updater, kill_recv, cleanup), fields(state = %State::Running))] async fn run( id: Uuid, service_name: String, - mut runtime_client: RuntimeClient, + runtime_client: &mut RuntimeClient, address: SocketAddr, - _kill_recv: KillReceiver, - _cleanup: impl FnOnce(std::result::Result, JoinError>) - + Send - + 'static, + deployment_updater: impl DeploymentUpdater, + mut kill_recv: KillReceiver, + cleanup: impl FnOnce(std::result::Result, Status>) + Send + 'static, ) { + deployment_updater.set_address(&id, &address).await.unwrap(); + let start_request = tonic::Request::new(StartRequest { deployment_id: id.as_bytes().to_vec(), - service_name, + service_name: service_name.clone(), port: address.port() as u32, }); info!("starting service"); let response = runtime_client.start(start_request).await.unwrap(); - info!(response = ?response.into_inner(), "client response: "); + info!(response = ?response.into_inner(), "start client response: "); + + let mut response = Err(Status::unknown("not stopped yet")); + + while let Ok(kill_id) = kill_recv.recv().await { + if kill_id == id { + let stop_request = tonic::Request::new(StopRequest { + deployment_id: id.as_bytes().to_vec(), + service_name: service_name.clone(), + }); + response = runtime_client.stop(stop_request).await; + + break; + } + } + + cleanup(response); } #[cfg(test)] mod tests { - use std::{path::PathBuf, process::Command, time::Duration}; + use std::{net::SocketAddr, path::PathBuf, process::Command, sync::Arc, time::Duration}; use async_trait::async_trait; use shuttle_common::storage_manager::ArtifactsStorageManager; - use shuttle_proto::runtime::runtime_client::RuntimeClient; + use shuttle_proto::runtime::StopResponse; use tempdir::TempDir; use tokio::{ - sync::{broadcast, oneshot}, - task::JoinError, + sync::{broadcast, oneshot, Mutex}, time::sleep, }; - use tonic::transport::Channel; + use tonic::{Response, Status}; use uuid::Uuid; use crate::{ error::Error, - persistence::{Secret, SecretGetter}, + persistence::{DeploymentUpdater, Secret, SecretGetter}, + RuntimeManager, }; use super::Built; @@ -315,10 +349,12 @@ mod tests { Ok(()) } - async fn get_runtime_client() -> RuntimeClient { - RuntimeClient::connect("http://127.0.0.1:6001") - .await - .unwrap() + fn get_runtime_manager() -> Arc> { + let tmp_dir = TempDir::new("shuttle_run_test").unwrap(); + let path = tmp_dir.into_path(); + let (tx, _rx) = crossbeam_channel::unbounded(); + + RuntimeManager::new(path, "http://localhost:5000".to_string(), tx) } #[derive(Clone)] @@ -337,6 +373,22 @@ mod tests { StubSecretGetter } + #[derive(Clone)] + struct StubDeploymentUpdater; + + #[async_trait] + impl DeploymentUpdater for StubDeploymentUpdater { + type Err = std::io::Error; + + async fn set_address(&self, _id: &Uuid, _address: &SocketAddr) -> Result<(), Self::Err> { + Ok(()) + } + + async fn set_is_next(&self, _id: &Uuid, _is_next: bool) -> Result<(), Self::Err> { + Ok(()) + } + } + // This test uses the kill signal to make sure a service does stop when asked to #[tokio::test] async fn can_be_killed() { @@ -345,14 +397,10 @@ mod tests { let (kill_send, kill_recv) = broadcast::channel(1); let (cleanup_send, cleanup_recv) = oneshot::channel(); - let handle_cleanup = |result: std::result::Result< - std::result::Result<(), shuttle_service::Error>, - JoinError, - >| { + let handle_cleanup = |result: std::result::Result, Status>| { assert!( - matches!(result, Err(ref join_error) if join_error.is_cancelled()), - "handle should have been cancelled: {:?}", - result + result.unwrap().into_inner().success, + "handle should have been cancelled", ); cleanup_send.send(()).unwrap(); }; @@ -362,7 +410,8 @@ mod tests { .handle( storage_manager, secret_getter, - get_runtime_client().await, + get_runtime_manager(), + StubDeploymentUpdater, kill_recv, kill_old_deployments(), handle_cleanup, @@ -389,16 +438,13 @@ mod tests { let (_kill_send, kill_recv) = broadcast::channel(1); let (cleanup_send, cleanup_recv) = oneshot::channel(); - let handle_cleanup = |result: std::result::Result< - std::result::Result<(), shuttle_service::Error>, - JoinError, - >| { - let result = result.unwrap(); - assert!( - result.is_ok(), - "did not expect error from self stopping service: {}", - result.unwrap_err() - ); + let handle_cleanup = |_result: std::result::Result, Status>| { + // let result = result.unwrap(); + // assert!( + // result.is_ok(), + // "did not expect error from self stopping service: {}", + // result.unwrap_err() + // ); cleanup_send.send(()).unwrap(); }; let secret_getter = get_secret_getter(); @@ -407,7 +453,8 @@ mod tests { .handle( storage_manager, secret_getter, - get_runtime_client().await, + get_runtime_manager(), + StubDeploymentUpdater, kill_recv, kill_old_deployments(), handle_cleanup, @@ -428,16 +475,13 @@ mod tests { let (_kill_send, kill_recv) = broadcast::channel(1); let (cleanup_send, cleanup_recv): (oneshot::Sender<()>, _) = oneshot::channel(); - let handle_cleanup = |result: std::result::Result< - std::result::Result<(), shuttle_service::Error>, - JoinError, - >| { - let result = result.unwrap(); - assert!( - matches!(result, Err(shuttle_service::Error::BindPanic(ref msg)) if msg == "panic in bind"), - "expected inner error from handle: {:?}", - result - ); + let handle_cleanup = |_result: std::result::Result, Status>| { + // let result = result.unwrap(); + // assert!( + // matches!(result, Err(shuttle_service::Error::BindPanic(ref msg)) if msg == "panic in bind"), + // "expected inner error from handle: {:?}", + // result + // ); cleanup_send.send(()).unwrap(); }; let secret_getter = get_secret_getter(); @@ -446,7 +490,8 @@ mod tests { .handle( storage_manager, secret_getter, - get_runtime_client().await, + get_runtime_manager(), + StubDeploymentUpdater, kill_recv, kill_old_deployments(), handle_cleanup, @@ -473,7 +518,8 @@ mod tests { .handle( storage_manager, secret_getter, - get_runtime_client().await, + get_runtime_manager(), + StubDeploymentUpdater, kill_recv, kill_old_deployments(), handle_cleanup, @@ -494,6 +540,7 @@ mod tests { service_name: "test".to_string(), service_id: Uuid::new_v4(), tracing_context: Default::default(), + is_next: false, }; let (_kill_send, kill_recv) = broadcast::channel(1); @@ -505,7 +552,8 @@ mod tests { .handle( storage_manager, secret_getter, - get_runtime_client().await, + get_runtime_manager(), + StubDeploymentUpdater, kill_recv, kill_old_deployments(), handle_cleanup, @@ -513,10 +561,7 @@ mod tests { .await; assert!( - matches!( - result, - Err(Error::Load(shuttle_service::loader::LoaderError::Load(_))) - ), + matches!(result, Err(Error::Load(_))), "expected missing 'so' error: {:?}", result ); @@ -554,6 +599,7 @@ mod tests { service_name: crate_name.to_string(), service_id: Uuid::new_v4(), tracing_context: Default::default(), + is_next: false, }, storage_manager, ) diff --git a/deployer/src/error.rs b/deployer/src/error.rs index 1adba5ae2..0f9ad03cc 100644 --- a/deployer/src/error.rs +++ b/deployer/src/error.rs @@ -2,8 +2,6 @@ use std::error::Error as StdError; use std::io; use thiserror::Error; -use shuttle_service::loader::LoaderError; - use cargo::util::errors::CargoTestError; use crate::deployment::gateway_client; @@ -15,7 +13,7 @@ pub enum Error { #[error("Build error: {0}")] Build(#[source] Box), #[error("Load error: {0}")] - Load(#[from] LoaderError), + Load(String), #[error("Prepare to run error: {0}")] PrepareRun(String), #[error("Run error: {0}")] @@ -30,6 +28,8 @@ pub enum Error { OldCleanup(#[source] Box), #[error("Gateway client error: {0}")] GatewayClient(#[from] gateway_client::Error), + #[error("Failed to get runtime: {0}")] + Runtime(#[source] anyhow::Error), } #[derive(Error, Debug)] diff --git a/deployer/src/handlers/mod.rs b/deployer/src/handlers/mod.rs index 848dbff7f..f3def2080 100644 --- a/deployer/src/handlers/mod.rs +++ b/deployer/src/handlers/mod.rs @@ -224,6 +224,7 @@ async fn post_service( state: State::Queued, last_update: Utc::now(), address: None, + is_next: false, }; let mut data = Vec::new(); diff --git a/deployer/src/lib.rs b/deployer/src/lib.rs index cb9b3bf0d..e5ae62450 100644 --- a/deployer/src/lib.rs +++ b/deployer/src/lib.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, net::SocketAddr}; +use std::{convert::Infallible, net::SocketAddr, sync::Arc}; pub use args::Args; pub use deployment::deploy_layer::DeployLayer; @@ -10,8 +10,8 @@ use hyper::{ }; pub use persistence::Persistence; use proxy::AddressGetter; -use shuttle_proto::runtime::runtime_client::RuntimeClient; -use tonic::transport::Channel; +pub use runtime_manager::RuntimeManager; +use tokio::sync::Mutex; use tracing::{error, info}; use crate::deployment::gateway_client::GatewayClient; @@ -22,14 +22,20 @@ mod error; mod handlers; mod persistence; mod proxy; +mod runtime_manager; -pub async fn start(persistence: Persistence, runtime_client: RuntimeClient, args: Args) { +pub async fn start( + persistence: Persistence, + runtime_manager: Arc>, + args: Args, +) { let deployment_manager = DeploymentManager::builder() .build_log_recorder(persistence.clone()) .secret_recorder(persistence.clone()) .active_deployment_getter(persistence.clone()) .artifacts_path(args.artifacts_path) - .runtime(runtime_client) + .runtime(runtime_manager) + .deployment_updater(persistence.clone()) .secret_getter(persistence.clone()) .queue_client(GatewayClient::new(args.gateway_uri)) .build(); @@ -44,6 +50,7 @@ pub async fn start(persistence: Persistence, runtime_client: RuntimeClient { error!("Proxy stopped.") }, - _ = start(persistence, runtime_client, args) => { + _ = start(persistence, runtime_manager, args) => { error!("Deployment service stopped.") }, - _ = runtime.wait() => { - error!("Legacy runtime stopped.") - }, - _ = logs_task => { - error!("Logs task stopped") - }, } exit(1); diff --git a/deployer/src/persistence/deployment.rs b/deployer/src/persistence/deployment.rs index 03d210066..b7e50f8b8 100644 --- a/deployer/src/persistence/deployment.rs +++ b/deployer/src/persistence/deployment.rs @@ -1,5 +1,6 @@ use std::{net::SocketAddr, str::FromStr}; +use async_trait::async_trait; use chrono::{DateTime, Utc}; use sqlx::{sqlite::SqliteRow, FromRow, Row}; use tracing::error; @@ -14,6 +15,7 @@ pub struct Deployment { pub state: State, pub last_update: DateTime, pub address: Option, + pub is_next: bool, } impl FromRow<'_, SqliteRow> for Deployment { @@ -36,6 +38,7 @@ impl FromRow<'_, SqliteRow> for Deployment { state: row.try_get("state")?, last_update: row.try_get("last_update")?, address, + is_next: row.try_get("is_next")?, }) } } @@ -51,12 +54,23 @@ impl From for shuttle_common::models::deployment::Response { } } +/// Update the details of a deployment +#[async_trait] +pub trait DeploymentUpdater: Clone + Send + Sync + 'static { + type Err: std::error::Error + Send; + + /// Set the address for a deployment + async fn set_address(&self, id: &Uuid, address: &SocketAddr) -> Result<(), Self::Err>; + + /// Set if a deployment is build on shuttle-next + async fn set_is_next(&self, id: &Uuid, is_next: bool) -> Result<(), Self::Err>; +} + #[derive(Debug, PartialEq, Eq)] pub struct DeploymentState { pub id: Uuid, pub state: State, pub last_update: DateTime, - pub address: Option, } #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] @@ -64,4 +78,5 @@ pub struct DeploymentRunnable { pub id: Uuid, pub service_name: String, pub service_id: Uuid, + pub is_next: bool, } diff --git a/deployer/src/persistence/mod.rs b/deployer/src/persistence/mod.rs index bc9ce7054..7ac9f3e2d 100644 --- a/deployer/src/persistence/mod.rs +++ b/deployer/src/persistence/mod.rs @@ -27,7 +27,7 @@ use tracing::{error, info, instrument, trace}; use uuid::Uuid; use self::deployment::DeploymentRunnable; -pub use self::deployment::{Deployment, DeploymentState}; +pub use self::deployment::{Deployment, DeploymentState, DeploymentUpdater}; pub use self::error::Error as PersistenceError; pub use self::log::{Level as LogLevel, Log}; pub use self::resource::{Resource, ResourceRecorder, Type as ResourceType}; @@ -158,13 +158,14 @@ impl Persistence { let deployment = deployment.into(); sqlx::query( - "INSERT INTO deployments (id, service_id, state, last_update, address) VALUES (?, ?, ?, ?, ?)", + "INSERT INTO deployments (id, service_id, state, last_update, address, is_next) VALUES (?, ?, ?, ?, ?, ?)", ) .bind(deployment.id) .bind(deployment.service_id) .bind(deployment.state) .bind(deployment.last_update) .bind(deployment.address.map(|socket| socket.to_string())) + .bind(deployment.is_next) .execute(&self.pool) .await .map(|_| ()) @@ -265,7 +266,7 @@ impl Persistence { pub async fn get_all_runnable_deployments(&self) -> Result> { sqlx::query_as( - r#"SELECT d.id, service_id, s.name AS service_name + r#"SELECT d.id, service_id, s.name AS service_name, d.is_next FROM deployments AS d JOIN services AS s ON s.id = d.service_id WHERE state = ? @@ -304,12 +305,9 @@ impl Persistence { async fn update_deployment(pool: &SqlitePool, state: impl Into) -> Result<()> { let state = state.into(); - // TODO: Handle moving to 'active_deployments' table for State::Running. - - sqlx::query("UPDATE deployments SET state = ?, last_update = ?, address = ? WHERE id = ?") + sqlx::query("UPDATE deployments SET state = ?, last_update = ? WHERE id = ?") .bind(state.state) .bind(state.last_update) - .bind(state.address.map(|socket| socket.to_string())) .bind(state.id) .execute(pool) .await @@ -442,6 +440,31 @@ impl AddressGetter for Persistence { } } +#[async_trait::async_trait] +impl DeploymentUpdater for Persistence { + type Err = Error; + + async fn set_address(&self, id: &Uuid, address: &SocketAddr) -> Result<()> { + sqlx::query("UPDATE deployments SET address = ? WHERE id = ?") + .bind(address.to_string()) + .bind(id) + .execute(&self.pool) + .await + .map(|_| ()) + .map_err(Error::from) + } + + async fn set_is_next(&self, id: &Uuid, is_next: bool) -> Result<()> { + sqlx::query("UPDATE deployments SET is_next = ? WHERE id = ?") + .bind(is_next) + .bind(id) + .execute(&self.pool) + .await + .map(|_| ()) + .map_err(Error::from) + } +} + #[async_trait::async_trait] impl ActiveDeploymentsGetter for Persistence { type Err = Error; @@ -493,7 +516,9 @@ mod tests { state: State::Queued, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 43, 33).unwrap(), address: None, + is_next: false, }; + let address = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 12345); p.insert_deployment(deployment.clone()).await.unwrap(); assert_eq!(p.get_deployment(&id).await.unwrap().unwrap(), deployment); @@ -504,13 +529,18 @@ mod tests { id, state: State::Built, last_update: Utc::now(), - address: None, }, ) .await .unwrap(); + + p.set_address(&id, &address).await.unwrap(); + p.set_is_next(&id, true).await.unwrap(); + let update = p.get_deployment(&id).await.unwrap().unwrap(); assert_eq!(update.state, State::Built); + assert_eq!(update.address, Some(address)); + assert!(update.is_next); assert_ne!( update.last_update, Utc.with_ymd_and_hms(2022, 4, 25, 4, 43, 33).unwrap() @@ -530,6 +560,7 @@ mod tests { state: State::Crashed, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 7, 29, 35).unwrap(), address: None, + is_next: false, }; let deployment_stopped = Deployment { id: Uuid::new_v4(), @@ -537,6 +568,7 @@ mod tests { state: State::Stopped, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 7, 49, 35).unwrap(), address: None, + is_next: false, }; let deployment_other = Deployment { id: Uuid::new_v4(), @@ -544,6 +576,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 7, 39, 39).unwrap(), address: None, + is_next: false, }; let deployment_running = Deployment { id: Uuid::new_v4(), @@ -551,6 +584,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 7, 48, 29).unwrap(), address: Some(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 9876)), + is_next: true, }; for deployment in [ @@ -590,6 +624,7 @@ mod tests { state: State::Crashed, last_update: Utc::now(), address: None, + is_next: false, }; let deployment_stopped = Deployment { id: Uuid::new_v4(), @@ -597,6 +632,7 @@ mod tests { state: State::Stopped, last_update: Utc::now(), address: None, + is_next: false, }; let deployment_running = Deployment { id: Uuid::new_v4(), @@ -604,6 +640,7 @@ mod tests { state: State::Running, last_update: Utc::now(), address: Some(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 9876)), + is_next: false, }; let deployment_queued = Deployment { id: queued_id, @@ -611,6 +648,7 @@ mod tests { state: State::Queued, last_update: Utc::now(), address: None, + is_next: false, }; let deployment_building = Deployment { id: building_id, @@ -618,6 +656,7 @@ mod tests { state: State::Building, last_update: Utc::now(), address: None, + is_next: false, }; let deployment_built = Deployment { id: built_id, @@ -625,6 +664,7 @@ mod tests { state: State::Built, last_update: Utc::now(), address: None, + is_next: true, }; let deployment_loading = Deployment { id: loading_id, @@ -632,6 +672,7 @@ mod tests { state: State::Loading, last_update: Utc::now(), address: None, + is_next: false, }; for deployment in [ @@ -690,6 +731,7 @@ mod tests { state: State::Built, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 29, 33).unwrap(), address: None, + is_next: false, }, Deployment { id: id_1, @@ -697,6 +739,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 29, 44).unwrap(), address: None, + is_next: false, }, Deployment { id: id_2, @@ -704,6 +747,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 33, 48).unwrap(), address: None, + is_next: true, }, Deployment { id: Uuid::new_v4(), @@ -711,6 +755,7 @@ mod tests { state: State::Crashed, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 38, 52).unwrap(), address: None, + is_next: true, }, Deployment { id: id_3, @@ -718,6 +763,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 42, 32).unwrap(), address: None, + is_next: false, }, ] { p.insert_deployment(deployment).await.unwrap(); @@ -731,16 +777,19 @@ mod tests { id: id_1, service_name: "foo".to_string(), service_id: foo_id, + is_next: false, }, DeploymentRunnable { id: id_2, service_name: "bar".to_string(), service_id: bar_id, + is_next: true, }, DeploymentRunnable { id: id_3, service_name: "foo".to_string(), service_id: foo_id, + is_next: false, }, ] ); @@ -759,6 +808,7 @@ mod tests { state: State::Running, last_update: Utc::now(), address: None, + is_next: true, }, Deployment { id: Uuid::new_v4(), @@ -766,6 +816,7 @@ mod tests { state: State::Running, last_update: Utc::now(), address: None, + is_next: false, }, ]; @@ -875,7 +926,6 @@ mod tests { target: "tests::log_recorder_event".to_string(), fields: json!({"message": "job queued"}), r#type: deploy_layer::LogType::Event, - address: None, }; p.record(event); @@ -910,6 +960,7 @@ mod tests { state: State::Queued, // Should be different from the state recorded below last_update: Utc.with_ymd_and_hms(2022, 4, 29, 2, 39, 39).unwrap(), address: None, + is_next: false, }) .await .unwrap(); @@ -923,7 +974,6 @@ mod tests { target: String::new(), fields: serde_json::Value::Null, r#type: deploy_layer::LogType::State, - address: Some("127.0.0.1:12345".to_string()), }; p.record(state); @@ -949,7 +999,8 @@ mod tests { service_id, state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 29, 2, 39, 59).unwrap(), - address: Some(SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 12345)), + address: None, + is_next: false, } ); } @@ -1126,6 +1177,7 @@ mod tests { state: State::Built, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 29, 33).unwrap(), address: None, + is_next: false, }, Deployment { id: Uuid::new_v4(), @@ -1133,6 +1185,7 @@ mod tests { state: State::Stopped, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 29, 44).unwrap(), address: None, + is_next: false, }, Deployment { id: id_1, @@ -1140,6 +1193,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 33, 48).unwrap(), address: None, + is_next: false, }, Deployment { id: Uuid::new_v4(), @@ -1147,6 +1201,7 @@ mod tests { state: State::Crashed, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 38, 52).unwrap(), address: None, + is_next: false, }, Deployment { id: id_2, @@ -1154,6 +1209,7 @@ mod tests { state: State::Running, last_update: Utc.with_ymd_and_hms(2022, 4, 25, 4, 42, 32).unwrap(), address: None, + is_next: true, }, ] { p.insert_deployment(deployment).await.unwrap(); diff --git a/deployer/src/runtime_manager.rs b/deployer/src/runtime_manager.rs new file mode 100644 index 000000000..089350c17 --- /dev/null +++ b/deployer/src/runtime_manager.rs @@ -0,0 +1,127 @@ +use std::{path::PathBuf, sync::Arc}; + +use anyhow::Context; +use shuttle_proto::runtime::{self, runtime_client::RuntimeClient, SubscribeLogsRequest}; +use tokio::{process, sync::Mutex}; +use tonic::transport::Channel; +use tracing::{info, instrument, trace}; + +use crate::deployment::deploy_layer; + +#[derive(Clone)] +pub struct RuntimeManager { + legacy: Option>, + legacy_process: Option>>, + next: Option>, + next_process: Option>>, + artifacts_path: PathBuf, + provisioner_address: String, + log_sender: crossbeam_channel::Sender, +} + +impl RuntimeManager { + pub fn new( + artifacts_path: PathBuf, + provisioner_address: String, + log_sender: crossbeam_channel::Sender, + ) -> Arc> { + Arc::new(Mutex::new(Self { + legacy: None, + legacy_process: None, + next: None, + next_process: None, + artifacts_path, + provisioner_address, + log_sender, + })) + } + + pub async fn get_runtime_client( + &mut self, + is_next: bool, + ) -> anyhow::Result<&mut RuntimeClient> { + if is_next { + Self::get_runtime_client_helper( + &mut self.next, + &mut self.next_process, + is_next, + 6002, + self.artifacts_path.clone(), + &self.provisioner_address, + self.log_sender.clone(), + ) + .await + } else { + Self::get_runtime_client_helper( + &mut self.legacy, + &mut self.legacy_process, + is_next, + 6001, + self.artifacts_path.clone(), + &self.provisioner_address, + self.log_sender.clone(), + ) + .await + } + } + + #[instrument(skip(runtime_option, process_option, log_sender))] + async fn get_runtime_client_helper<'a>( + runtime_option: &'a mut Option>, + process_option: &mut Option>>, + is_next: bool, + port: u16, + artifacts_path: PathBuf, + provisioner_address: &str, + log_sender: crossbeam_channel::Sender, + ) -> anyhow::Result<&'a mut RuntimeClient> { + if let Some(runtime_client) = runtime_option { + trace!("returning previous client"); + Ok(runtime_client) + } else { + trace!("making new client"); + let (process, runtime_client) = runtime::start( + is_next, + runtime::StorageManagerType::Artifacts(artifacts_path), + provisioner_address, + port, + ) + .await + .context("failed to start shuttle runtime")?; + + let sender = log_sender; + let mut stream = runtime_client + .clone() + .subscribe_logs(tonic::Request::new(SubscribeLogsRequest {})) + .await + .context("subscribing to runtime logs stream")? + .into_inner(); + + tokio::spawn(async move { + while let Ok(Some(log)) = stream.message().await { + sender.send(log.into()).expect("to send log to persistence"); + } + }); + + *runtime_option = Some(runtime_client); + *process_option = Some(Arc::new(std::sync::Mutex::new(process))); + + // Safe to unwrap as it was just set + Ok(runtime_option.as_mut().unwrap()) + } + } +} + +impl Drop for RuntimeManager { + fn drop(&mut self) { + info!("runtime manager shutting down"); + + if let Some(ref process) = self.legacy_process.take() { + let _ = process.lock().unwrap().start_kill(); + } + + if let Some(ref process) = self.next_process.take() { + let _ = process.lock().unwrap().start_kill(); + } + } +} diff --git a/proto/src/lib.rs b/proto/src/lib.rs index 887e54a81..122a2cb0c 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -241,6 +241,7 @@ pub mod runtime { wasm: bool, storage_manager_type: StorageManagerType, provisioner_address: &str, + port: u16, ) -> anyhow::Result<(process::Child, runtime_client::RuntimeClient)> { let runtime_flag = if wasm { "--axum" } else { "--legacy" }; @@ -254,6 +255,8 @@ pub mod runtime { let runtime = process::Command::new(runtime_executable) .args([ runtime_flag, + "--port", + &port.to_string(), "--provisioner-address", provisioner_address, "--storage-manager-type", @@ -269,7 +272,7 @@ pub mod runtime { tokio::time::sleep(Duration::from_secs(2)).await; info!("connecting runtime client"); - let conn = Endpoint::new("http://127.0.0.1:6001") + let conn = Endpoint::new(format!("http://127.0.0.1:{port}")) .context("creating runtime client endpoint")? .connect_timeout(Duration::from_secs(5)); diff --git a/runtime/src/args.rs b/runtime/src/args.rs index 016121d4b..57e78eced 100644 --- a/runtime/src/args.rs +++ b/runtime/src/args.rs @@ -5,6 +5,10 @@ use tonic::transport::Endpoint; #[derive(Parser, Debug)] pub struct Args { + /// Port to start runtime on + #[arg(long)] + pub port: u16, + /// Address to reach provisioner at #[arg(long, default_value = "http://localhost:5000")] pub provisioner_address: Endpoint, diff --git a/runtime/src/legacy/mod.rs b/runtime/src/legacy/mod.rs index c3d98c44f..9c2663f60 100644 --- a/runtime/src/legacy/mod.rs +++ b/runtime/src/legacy/mod.rs @@ -76,6 +76,11 @@ where trace!(path, "loading"); let so_path = PathBuf::from(path); + + if !so_path.exists() { + return Err(Status::not_found("'.so' to load does not exist")); + } + *self.so_path.lock().unwrap() = Some(so_path); *self.secrets.lock().unwrap() = Some(BTreeMap::from_iter(secrets.into_iter())); @@ -88,6 +93,8 @@ where &self, request: Request, ) -> Result, Status> { + trace!("legacy starting"); + let provisioner_client = ProvisionerClient::connect(self.provisioner_address.clone()) .await .expect("failed to connect to provisioner"); @@ -116,6 +123,8 @@ where .map_err(|err| Status::from_error(Box::new(err)))? .clone(); + trace!("prepare done"); + let StartRequest { deployment_id, service_name, @@ -134,6 +143,7 @@ where secrets, self.storage_manager.clone(), ); + trace!("got factory"); let logs_tx = self.logs_tx.lock().unwrap().clone(); diff --git a/runtime/src/main.rs b/runtime/src/main.rs index b7d9f3c3e..8a3a5f803 100644 --- a/runtime/src/main.rs +++ b/runtime/src/main.rs @@ -27,7 +27,7 @@ async fn main() { trace!(args = ?args, "parsed args"); - let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 6001); + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), args.port); let provisioner_address = args.provisioner_address; let mut server_builder = diff --git a/service/src/loader.rs b/service/src/loader.rs index 98f42db29..9f130716c 100644 --- a/service/src/loader.rs +++ b/service/src/loader.rs @@ -72,6 +72,8 @@ impl Loader { addr: SocketAddr, logger: logger::Logger, ) -> Result { + trace!("loading service"); + let mut bootstrapper = self.bootstrapper; AssertUnwindSafe(bootstrapper.bootstrap(factory, logger)) diff --git a/tmp/axum-wasm/Cargo.toml b/tmp/axum-wasm/Cargo.toml index 341c6524f..e9e38ca2c 100644 --- a/tmp/axum-wasm/Cargo.toml +++ b/tmp/axum-wasm/Cargo.toml @@ -9,3 +9,7 @@ crate-type = [ "cdylib" ] [dependencies] shuttle-next = "0.8.0" tracing = "0.1.37" + +[dev-dependencies] +tokio = { version = "1.22.0", features = ["macros", "rt-multi-thread"] } +hyper = "0.14.23" diff --git a/tmp/axum-wasm/src/lib.rs b/tmp/axum-wasm/src/lib.rs index a6af35583..14776b925 100644 --- a/tmp/axum-wasm/src/lib.rs +++ b/tmp/axum-wasm/src/lib.rs @@ -13,3 +13,27 @@ shuttle_next::app! { "Goodbye, World!" } } + +#[cfg(test)] +mod tests { + use crate::__app; + use http::Request; + use hyper::Method; + + #[tokio::test] + async fn hello() { + let request = Request::builder() + .uri("http://local.test/hello") + .method(Method::GET) + .body(axum::body::boxed(axum::body::Body::empty())) + .unwrap(); + + let response = __app(request).await; + + assert!(response.status().is_success()); + + let body = &hyper::body::to_bytes(response.into_body()).await.unwrap(); + + assert_eq!(body, "Hello, World!"); + } +}