diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index 69ca3cfdb..524bb42a2 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -190,8 +190,7 @@ pub mod tests { #[tokio::test] async fn api_create_get_delete_projects() -> anyhow::Result<()> { let world = World::new().await; - let service = - Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); let (sender, mut receiver) = channel::(256); tokio::spawn(async move { @@ -326,8 +325,7 @@ pub mod tests { #[tokio::test] async fn api_create_get_users() -> anyhow::Result<()> { let world = World::new().await; - let service = - Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); let (sender, mut receiver) = channel::(256); tokio::spawn(async move { @@ -416,8 +414,7 @@ pub mod tests { #[tokio::test(flavor = "multi_thread")] async fn status() { let world = World::new().await; - let service = - Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); let (sender, mut receiver) = channel::(1); let (ctl_send, ctl_recv) = oneshot::channel(); diff --git a/gateway/src/args.rs b/gateway/src/args.rs index 5f65a7480..ae0b1bced 100644 --- a/gateway/src/args.rs +++ b/gateway/src/args.rs @@ -19,6 +19,7 @@ pub struct Args { pub enum Commands { Start(StartArgs), Init(InitArgs), + Exec(ExecCmds), } #[derive(clap::Args, Debug, Clone)] @@ -29,6 +30,35 @@ pub struct StartArgs { /// Address to bind the user plane to #[arg(long, default_value = "127.0.0.1:8000")] pub user: SocketAddr, + #[command(flatten)] + pub context: ContextArgs, +} + +#[derive(clap::Args, Debug, Clone)] +pub struct InitArgs { + /// Name of initial account to create + #[arg(long)] + pub name: String, + /// Key to assign to initial account + #[arg(long)] + pub key: Option, +} + +#[derive(clap::Args, Debug, Clone)] +pub struct ExecCmds { + #[command(flatten)] + pub context: ContextArgs, + #[command(subcommand)] + pub command: ExecCmd, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum ExecCmd { + Revive, +} + +#[derive(clap::Args, Debug, Clone)] +pub struct ContextArgs { /// Default image to deploy user runtimes into #[arg(long, default_value = "public.ecr.aws/shuttle/deployer:latest")] pub image: String, @@ -40,23 +70,13 @@ pub struct StartArgs { /// the provisioner service #[arg(long, default_value = "provisioner")] pub provisioner_host: String, - /// The path to the docker daemon socket - #[arg(long, default_value = "/var/run/docker.sock")] - pub docker_host: String, /// The Docker Network name in which to deploy user runtimes #[arg(long, default_value = "shuttle_default")] pub network_name: String, /// FQDN where the proxy can be reached at #[arg(long)] pub proxy_fqdn: FQDN, -} - -#[derive(clap::Args, Debug, Clone)] -pub struct InitArgs { - /// Name of initial account to create - #[arg(long)] - pub name: String, - /// Key to assign to initial account - #[arg(long)] - pub key: Option, + /// The path to the docker daemon socket + #[arg(long, default_value = "/var/run/docker.sock")] + pub docker_host: String, } diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index d5d2980bd..6a99be184 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -287,7 +287,7 @@ pub mod tests { use tracing::info; use crate::api::make_api; - use crate::args::StartArgs; + use crate::args::{ContextArgs, StartArgs}; use crate::auth::User; use crate::proxy::make_proxy; use crate::service::{ContainerSettings, GatewayService, MIGRATIONS}; @@ -485,7 +485,6 @@ pub mod tests { args: StartArgs, hyper: HyperClient, pool: SqlitePool, - fqdn: String, } #[derive(Clone, Copy)] @@ -493,13 +492,11 @@ pub mod tests { pub docker: &'c Docker, pub container_settings: &'c ContainerSettings, pub hyper: &'c HyperClient, - pub fqdn: &'c str, } impl World { pub async fn new() -> Self { let docker = Docker::connect_with_local_defaults().unwrap(); - let fqdn = "test.shuttleapp.rs".to_string(); docker .list_images::<&str>(None) @@ -529,17 +526,19 @@ pub mod tests { let args = StartArgs { control, - docker_host, user, - image, - prefix, - provisioner_host, - network_name, - proxy_fqdn: FQDN::from_str(&fqdn).unwrap(), + context: ContextArgs { + docker_host, + image, + prefix, + provisioner_host, + network_name, + proxy_fqdn: FQDN::from_str("test.shuttleapp.rs").unwrap(), + }, }; - let settings = ContainerSettings::builder(&docker, fqdn.clone()) - .from_args(&args) + let settings = ContainerSettings::builder(&docker) + .from_args(&args.context) .await; let hyper = HyperClient::builder().build(HttpConnector::new()); @@ -553,12 +552,11 @@ pub mod tests { args, hyper, pool, - fqdn, } } - pub fn args(&self) -> StartArgs { - self.args.clone() + pub fn args(&self) -> ContextArgs { + self.args.context.clone() } pub fn pool(&self) -> SqlitePool { @@ -570,7 +568,11 @@ pub mod tests { } pub fn fqdn(&self) -> String { - self.fqdn.clone() + self.args() + .proxy_fqdn + .to_string() + .trim_end_matches('.') + .to_string() } } @@ -580,7 +582,6 @@ pub mod tests { docker: &self.docker, container_settings: &self.settings, hyper: &self.hyper, - fqdn: &self.fqdn, } } } @@ -598,8 +599,7 @@ pub mod tests { #[tokio::test] async fn end_to_end() { let world = World::new().await; - let service = - Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + let service = Arc::new(GatewayService::init(world.args(), world.pool()).await); let worker = Worker::new(Arc::clone(&service)); let (log_out, mut log_in) = channel(256); diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 5a01f5e48..24df027fc 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -1,12 +1,12 @@ use clap::Parser; use futures::prelude::*; -use shuttle_gateway::args::{Args, Commands, InitArgs}; +use shuttle_gateway::args::{Args, Commands, ExecCmd, ExecCmds, InitArgs}; use shuttle_gateway::auth::Key; use shuttle_gateway::proxy::make_proxy; use shuttle_gateway::service::{GatewayService, MIGRATIONS}; use shuttle_gateway::worker::{Work, Worker}; use shuttle_gateway::{api::make_api, args::StartArgs}; -use shuttle_gateway::{Refresh, Service}; +use shuttle_gateway::{project, Refresh, Service}; use sqlx::migrate::MigrateDatabase; use sqlx::{query, Sqlite, SqlitePool}; use std::io; @@ -55,16 +55,18 @@ async fn main() -> io::Result<()> { match args.command { Commands::Start(start_args) => start(db, start_args).await, Commands::Init(init_args) => init(db, init_args).await, + Commands::Exec(exec_cmd) => exec(db, exec_cmd).await, } } async fn start(db: SqlitePool, args: StartArgs) -> io::Result<()> { let fqdn = args + .context .proxy_fqdn .to_string() .trim_end_matches('.') .to_string(); - let gateway = Arc::new(GatewayService::init(args.clone(), fqdn.clone(), db).await); + let gateway = Arc::new(GatewayService::init(args.context.clone(), db).await); let worker = Worker::new(Arc::clone(&gateway)); @@ -146,3 +148,15 @@ async fn init(db: SqlitePool, args: InitArgs) -> io::Result<()> { println!("`{}` created as super user with key: {key}", args.name); Ok(()) } + +async fn exec(db: SqlitePool, exec_cmd: ExecCmds) -> io::Result<()> { + let gateway = GatewayService::init(exec_cmd.context.clone(), db).await; + + match exec_cmd.command { + ExecCmd::Revive => project::exec::revive(gateway) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?, + }; + + Ok(()) +} diff --git a/gateway/src/project.rs b/gateway/src/project.rs index 7f5d983ec..4e80925c3 100644 --- a/gateway/src/project.rs +++ b/gateway/src/project.rs @@ -703,6 +703,62 @@ impl<'c> State<'c> for ProjectError { } } +pub mod exec { + use bollard::service::ContainerState; + + use crate::{ + service::GatewayService, + worker::{do_work, Work}, + }; + + use super::*; + + pub async fn revive(gateway: GatewayService) -> Result<(), ProjectError> { + let mut mutations = Vec::new(); + + for Work { + project_name, + account_name, + work, + } in gateway + .iter_projects() + .await + .expect("could not list projects") + { + if let Project::Errored(ProjectError { ctx: Some(ctx), .. }) = work { + if let Some(container) = ctx.container() { + if let Ok(container) = gateway + .context() + .docker() + .inspect_container(safe_unwrap!(container.id), None) + .await + { + if let Some(ContainerState { + status: Some(ContainerStateStatusEnum::EXITED), + .. + }) = container.state + { + mutations.push(Work { + project_name, + account_name, + work: Project::Stopped(ProjectStopped { container }), + }); + } + } + } + } + } + + for work in mutations { + debug!(?work, "project will be revived"); + + do_work(work, &gateway).await; + } + + Ok(()) + } +} + #[cfg(test)] pub mod tests { diff --git a/gateway/src/service.rs b/gateway/src/service.rs index 2cae0e7bf..d8a84f712 100644 --- a/gateway/src/service.rs +++ b/gateway/src/service.rs @@ -20,7 +20,7 @@ use sqlx::types::Json as SqlxJson; use sqlx::{query, Error as SqlxError, Row}; use tracing::debug; -use crate::args::StartArgs; +use crate::args::ContextArgs; use crate::auth::{Key, User}; use crate::project::{self, Project}; use crate::worker::Work; @@ -43,33 +43,35 @@ pub struct ContainerSettingsBuilder<'d> { image: Option, provisioner: Option, network_name: Option, - fqdn: String, + fqdn: Option, } impl<'d> ContainerSettingsBuilder<'d> { - pub fn new(docker: &'d Docker, fqdn: String) -> Self { + pub fn new(docker: &'d Docker) -> Self { Self { docker, prefix: None, image: None, provisioner: None, network_name: None, - fqdn, + fqdn: None, } } - pub async fn from_args(self, args: &StartArgs) -> ContainerSettings { - let StartArgs { + pub async fn from_args(self, args: &ContextArgs) -> ContainerSettings { + let ContextArgs { prefix, network_name, provisioner_host, image, + proxy_fqdn, .. } = args; self.prefix(prefix) .image(image) .provisioner_host(provisioner_host) .network_name(network_name) + .fqdn(proxy_fqdn) .build() .await } @@ -94,6 +96,11 @@ impl<'d> ContainerSettingsBuilder<'d> { self } + pub fn fqdn(mut self, fqdn: S) -> Self { + self.fqdn = Some(fqdn.to_string().trim_end_matches('.').to_string()); + self + } + /// Resolves the Docker network ID for the given network name. /// /// # Panics @@ -125,7 +132,7 @@ impl<'d> ContainerSettingsBuilder<'d> { let network_name = self.network_name.take().unwrap(); let network_id = self.resolve_network_id(&network_name).await; - let fqdn = self.fqdn; + let fqdn = self.fqdn.take().unwrap(); ContainerSettings { prefix, @@ -148,8 +155,8 @@ pub struct ContainerSettings { } impl ContainerSettings { - pub fn builder(docker: &Docker, fqdn: String) -> ContainerSettingsBuilder { - ContainerSettingsBuilder::new(docker, fqdn) + pub fn builder(docker: &Docker) -> ContainerSettingsBuilder { + ContainerSettingsBuilder::new(docker) } } @@ -181,12 +188,10 @@ impl GatewayService { /// /// * `args` - The [`Args`] with which the service was /// started. Will be passed as [`Context`] to workers and state. - pub async fn init(args: StartArgs, fqdn: String, db: SqlitePool) -> Self { + pub async fn init(args: ContextArgs, db: SqlitePool) -> Self { let docker = Docker::connect_with_unix(&args.docker_host, 60, API_DEFAULT_VERSION).unwrap(); - let container_settings = ContainerSettings::builder(&docker, fqdn) - .from_args(&args) - .await; + let container_settings = ContainerSettings::builder(&docker).from_args(&args).await; let provider = GatewayContextProvider::new(docker, container_settings); @@ -439,11 +444,33 @@ impl GatewayService { }) } - fn context(&self) -> GatewayContext { + pub fn context(&self) -> GatewayContext { self.provider.context() } } +#[async_trait] +impl<'c> Service<'c> for GatewayService { + type Context = GatewayContext<'c>; + + type State = Work; + + type Error = Error; + + fn context(&'c self) -> Self::Context { + GatewayService::context(self) + } + + async fn update( + &self, + Work { + project_name, work, .. + }: &Self::State, + ) -> Result<(), Self::Error> { + self.update_project(project_name, work).await + } +} + #[async_trait] impl<'c> Service<'c> for Arc { type Context = GatewayContext<'c>; @@ -492,7 +519,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_user() -> anyhow::Result<()> { let world = World::new().await; - let svc = GatewayService::init(world.args(), world.fqdn(), world.pool()).await; + let svc = GatewayService::init(world.args(), world.pool()).await; let account_name: AccountName = "test_user_123".parse()?; @@ -543,7 +570,7 @@ pub mod tests { #[tokio::test] async fn service_create_find_delete_project() -> anyhow::Result<()> { let world = World::new().await; - let svc = Arc::new(GatewayService::init(world.args(), world.fqdn(), world.pool()).await); + let svc = Arc::new(GatewayService::init(world.args(), world.pool()).await); let neo: AccountName = "neo".parse().unwrap(); let matrix: ProjectName = "matrix".parse().unwrap(); diff --git a/gateway/src/worker.rs b/gateway/src/worker.rs index ee71dfd8a..d371bdb66 100644 --- a/gateway/src/worker.rs +++ b/gateway/src/worker.rs @@ -115,33 +115,45 @@ where let _ = self.send.take().unwrap(); debug!("starting worker"); - while let Some(mut work) = self.recv.recv().await { + while let Some(work) = self.recv.recv().await { debug!(?work, "received work"); - loop { - work = { - let context = self.service.context(); - - // Safety: EndState's transitions are Infallible - work.next(&context).await.unwrap() - }; - - match self.service.update(&work).await { - Ok(_) => {} - Err(err) => info!("failed to update a state: {}\nstate: {:?}", err, work), - }; - - if work.is_done() { - break; - } else { - debug!(?work, "work not done yet"); - } - } + do_work(work, &self.service).await; } Ok(self) } } +pub async fn do_work< + 'c, + E: std::fmt::Display, + S: Service<'c, State = W, Error = E>, + W: EndState<'c> + Debug, +>( + mut work: W, + service: &'c S, +) { + loop { + work = { + let context = service.context(); + + // Safety: EndState's transitions are Infallible + work.next(&context).await.unwrap() + }; + + match service.update(&work).await { + Ok(_) => {} + Err(err) => info!("failed to update a state: {}\nstate: {:?}", err, work), + }; + + if work.is_done() { + break; + } else { + debug!(?work, "work not done yet"); + } + } +} + #[cfg(test)] pub mod tests { use std::convert::Infallible; diff --git a/gateway/tests/hello_world.crate b/gateway/tests/hello_world.crate index d4f72b6be..038d4d03e 100644 Binary files a/gateway/tests/hello_world.crate and b/gateway/tests/hello_world.crate differ