diff --git a/cas/grpc_service/BUILD b/cas/grpc_service/BUILD index c2ec8cca1..f9ff6257c 100644 --- a/cas/grpc_service/BUILD +++ b/cas/grpc_service/BUILD @@ -81,6 +81,7 @@ rust_library( "//third_party:futures", "//third_party:rand", "//third_party:stdext", + "//third_party:tokio", "//third_party:tokio_stream", "//third_party:tonic", "//util:common", @@ -126,6 +127,7 @@ rust_test( deps = [ "//cas/scheduler", "//cas/scheduler:platform_property_manager", + "//cas/scheduler:worker", "//proto", "//third_party:pretty_assertions", "//third_party:tokio", diff --git a/cas/grpc_service/execution_server.rs b/cas/grpc_service/execution_server.rs index 7827321e5..a56a54be0 100644 --- a/cas/grpc_service/execution_server.rs +++ b/cas/grpc_service/execution_server.rs @@ -3,10 +3,11 @@ use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use std::time::{Duration, Instant, SystemTime}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use futures::{Stream, StreamExt}; use rand::{thread_rng, Rng}; +use tokio::time::interval; use tokio_stream::wrappers::WatchStream; use tonic::{Request, Response, Status}; @@ -27,8 +28,11 @@ use store::{Store, StoreManager}; /// Default priority remote execution jobs will get when not provided. const DEFAULT_EXECUTION_PRIORITY: i64 = 0; +/// Default timeout for workers in seconds. +const DEFAULT_WORKER_TIMEOUT_S: u64 = 5; + struct InstanceInfo { - scheduler: Scheduler, + scheduler: Arc, cas_store: Arc, platform_property_manager: PlatformPropertyManager, } @@ -120,14 +124,38 @@ impl ExecutionServer { .clone() .unwrap_or(HashMap::new()), ); + let mut worker_timeout_s = exec_cfg.worker_timeout_s; + if worker_timeout_s == 0 { + worker_timeout_s = DEFAULT_WORKER_TIMEOUT_S; + } + let scheduler = Arc::new(Scheduler::new(worker_timeout_s)); + let weak_scheduler = Arc::downgrade(&scheduler); instance_infos.insert( instance_name.to_string(), InstanceInfo { - scheduler: Scheduler::new(), + scheduler, cas_store, platform_property_manager, }, ); + tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(1)); + loop { + ticker.tick().await; + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Error: system time is now behind unix epoch"); + match weak_scheduler.upgrade() { + Some(scheduler) => { + if let Err(e) = scheduler.remove_timedout_workers(timestamp.as_secs()).await { + log::error!("Error while running remove_timedout_workers : {:?}", e); + } + } + // If we fail to upgrade, our service is probably destroyed, so return. + None => return, + } + } + }); } Ok(Self { instance_infos }) } diff --git a/cas/grpc_service/tests/worker_api_server_test.rs b/cas/grpc_service/tests/worker_api_server_test.rs index fc407fb1d..770fa8a1e 100644 --- a/cas/grpc_service/tests/worker_api_server_test.rs +++ b/cas/grpc_service/tests/worker_api_server_test.rs @@ -1,56 +1,200 @@ // Copyright 2022 Nathan (Blaise) Bruer. All rights reserved. 
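+// These tests drive WorkerApiServer through `new_with_now_fn` with an injected `now_fn` clock,
+// so the worker-timeout behaviour can be exercised deterministically instead of by sleeping.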
use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; +use std::time::Duration; use tokio_stream::StreamExt; use tonic::Request; -use error::ResultExt; +use error::{Error, ResultExt}; use platform_property_manager::PlatformPropertyManager; use proto::com::github::allada::turbo_cache::remote_execution::{ - update_for_worker, worker_api_server::WorkerApi, SupportedProperties, + update_for_worker, worker_api_server::WorkerApi, KeepAliveRequest, SupportedProperties, }; use scheduler::Scheduler; -use worker_api_server::WorkerApiServer; +use worker::WorkerId; +use worker_api_server::{ConnectWorkerStream, NowFn, WorkerApiServer}; + +const BASE_NOW_S: u64 = 10; +const BASE_WORKER_TIMEOUT_S: u64 = 100; + +struct TestContext { + scheduler: Arc, + worker_api_server: WorkerApiServer, + connection_worker_stream: ConnectWorkerStream, + worker_id: WorkerId, +} + +fn static_now_fn() -> Result { + Ok(Duration::from_secs(BASE_NOW_S)) +} + +async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result { + let platform_properties = HashMap::new(); + let scheduler = Arc::new(Scheduler::new(worker_timeout)); + let worker_api_server = WorkerApiServer::new_with_now_fn( + Arc::new(PlatformPropertyManager::new(platform_properties)), + scheduler.clone(), + now_fn, + ); + + let supported_properties = SupportedProperties::default(); + let mut connection_worker_stream = worker_api_server + .connect_worker(Request::new(supported_properties)) + .await? + .into_inner(); + + let maybe_first_message = connection_worker_stream.next().await; + assert!(maybe_first_message.is_some(), "Expected first message from stream"); + let first_update = maybe_first_message + .unwrap() + .err_tip(|| "Expected success result")? + .update + .err_tip(|| "Expected update field to be populated")?; + let worker_id = match first_update { + update_for_worker::Update::ConnectionResult(connection_result) => connection_result.worker_id, + other => unreachable!("Expected ConnectionResult, got {:?}", other), + }; + + const UUID_SIZE: usize = 36; + assert_eq!(worker_id.len(), UUID_SIZE, "Worker ID should be 36 characters"); + + Ok(TestContext { + scheduler, + worker_api_server, + connection_worker_stream, + worker_id: worker_id.try_into()?, + }) +} #[cfg(test)] pub mod connect_worker_tests { use super::*; - use pretty_assertions::assert_eq; // Must be declared in every module. #[tokio::test] pub async fn connect_worker_adds_worker_to_scheduler_test() -> Result<(), Box> { - let platform_properties = HashMap::new(); - let scheduler = Arc::new(Scheduler::new()); - let worker_api_server = WorkerApiServer::new( - Arc::new(PlatformPropertyManager::new(platform_properties)), - scheduler.clone(), - ); - - let supported_properties = SupportedProperties::default(); - let mut update_for_worker_stream = worker_api_server - .connect_worker(Request::new(supported_properties)) - .await? - .into_inner(); - - let maybe_first_message = update_for_worker_stream.next().await; - assert!(maybe_first_message.is_some(), "Expected first message from stream"); - let first_update = maybe_first_message - .unwrap() - .err_tip(|| "Expected success result")? 
- .update - .err_tip(|| "Expected update field to be populated")?; - let worker_id = match first_update { - update_for_worker::Update::ConnectionResult(connection_result) => connection_result.worker_id, - other => unreachable!("Expected ConnectionResult, got {:?}", other), + let test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; + + let worker_exists = test_context + .scheduler + .contains_worker_for_test(&test_context.worker_id) + .await; + assert!(worker_exists, "Expected worker to exist in worker map"); + + Ok(()) + } +} + +#[cfg(test)] +pub mod keep_alive_tests { + use super::*; + use pretty_assertions::assert_eq; // Must be declared in every module. + + #[tokio::test] + pub async fn server_times_out_workers_test() -> Result<(), Box> { + let test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; + + let mut now_timestamp = BASE_NOW_S; + { + // Now change time to 1 second before timeout and ensure the worker is still in the pool. + now_timestamp += BASE_WORKER_TIMEOUT_S - 1; + test_context.scheduler.remove_timedout_workers(now_timestamp).await?; + let worker_exists = test_context + .scheduler + .contains_worker_for_test(&test_context.worker_id) + .await; + assert!(worker_exists, "Expected worker to exist in worker map"); + } + { + // Now add 1 second and our worker should have been evicted due to timeout. + now_timestamp += 1; + test_context.scheduler.remove_timedout_workers(now_timestamp).await?; + let worker_exists = test_context + .scheduler + .contains_worker_for_test(&test_context.worker_id) + .await; + assert!(!worker_exists, "Expected worker to not exist in map"); + } + + Ok(()) + } + + #[tokio::test] + pub async fn server_does_not_timeout_if_keep_alive_test() -> Result<(), Box> { + let now_timestamp = Arc::new(Mutex::new(BASE_NOW_S)); + let now_timestamp_clone = now_timestamp.clone(); + let add_and_return_timestamp = move |add_amount: u64| -> u64 { + let mut locked_now_timestamp = now_timestamp.lock().unwrap(); + *locked_now_timestamp += add_amount; + *locked_now_timestamp }; - const UUID_SIZE: usize = 36; - assert_eq!(worker_id.len(), UUID_SIZE, "Worker ID should be 36 characters"); + let test_context = setup_api_server( + BASE_WORKER_TIMEOUT_S, + Box::new(move || Ok(Duration::from_secs(*now_timestamp_clone.lock().unwrap()))), + ) + .await?; + { + // Now change time to 1 second before timeout and ensure the worker is still in the pool. + let timestamp = add_and_return_timestamp(BASE_WORKER_TIMEOUT_S - 1); + test_context.scheduler.remove_timedout_workers(timestamp).await?; + let worker_exists = test_context + .scheduler + .contains_worker_for_test(&test_context.worker_id) + .await; + assert!(worker_exists, "Expected worker to exist in worker map"); + } + { + // Now send keep alive. + test_context + .worker_api_server + .keep_alive(Request::new(KeepAliveRequest { + worker_id: test_context.worker_id.to_string(), + })) + .await + .err_tip(|| "Error sending keep alive")?; + } + { + // Now add 1 second and our worker should still exist in our map. 
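+            // (The keep_alive sent above refreshed the worker's last_update_timestamp via `now_fn`,
+            // so the timeout window is now measured from that refreshed time.)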
+            let timestamp = add_and_return_timestamp(1);
+            test_context.scheduler.remove_timedout_workers(timestamp).await?;
+            let worker_exists = test_context
+                .scheduler
+                .contains_worker_for_test(&test_context.worker_id)
+                .await;
+            assert!(worker_exists, "Expected worker to exist in map");
+        }
 
-        let worker_exists = scheduler.contains_worker_for_test(&worker_id.try_into()?).await;
-        assert!(worker_exists, "Expected worker to exist in worker map");
+        Ok(())
+    }
+
+    #[tokio::test]
+    pub async fn worker_receives_keep_alive_request_test() -> Result<(), Box<dyn std::error::Error>> {
+        let mut test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?;
+
+        // Send keep alive to client.
+        test_context
+            .scheduler
+            .send_keep_alive_to_worker_for_test(&test_context.worker_id)
+            .await
+            .err_tip(|| "Could not send keep alive to worker")?;
+
+        {
+            // Read stream and ensure it was a keep alive message.
+            let maybe_message = test_context.connection_worker_stream.next().await;
+            assert!(maybe_message.is_some(), "Expected next message in stream to exist");
+            let update_message = maybe_message
+                .unwrap()
+                .err_tip(|| "Expected success result")?
+                .update
+                .err_tip(|| "Expected update field to be populated")?;
+            assert_eq!(
+                update_message,
+                update_for_worker::Update::KeepAlive(()),
+                "Expected KeepAlive message"
+            );
+        }
 
         Ok(())
     }
diff --git a/cas/grpc_service/worker_api_server.rs b/cas/grpc_service/worker_api_server.rs
index 454872e6b..99adce816 100644
--- a/cas/grpc_service/worker_api_server.rs
+++ b/cas/grpc_service/worker_api_server.rs
@@ -2,7 +2,7 @@
 use std::pin::Pin;
 use std::sync::Arc;
-use std::time::Instant;
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 
 use futures::{stream::unfold, Stream};
 use tokio::sync::mpsc;
@@ -10,26 +10,49 @@ use tonic::{Request, Response, Status};
 use uuid::Uuid;
 
 use common::log;
-use error::{Error, ResultExt};
+use error::{make_err, Code, Error, ResultExt};
 use platform_property_manager::{PlatformProperties, PlatformPropertyManager};
 use proto::com::github::allada::turbo_cache::remote_execution::{
-    worker_api_server::WorkerApi, ExecuteResult, SupportedProperties, UpdateForWorker,
+    worker_api_server::WorkerApi, ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties,
+    UpdateForWorker,
 };
 use scheduler::Scheduler;
 use worker::{Worker, WorkerId};
 
-type ConnectWorkerStream = Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>;
+pub type ConnectWorkerStream = Pin<Box<dyn Stream<Item = Result<UpdateForWorker, Status>> + Send + Sync + 'static>>;
+
+pub type NowFn = Box<dyn Fn() -> Result<Duration, Error> + Send + Sync>;
 
 pub struct WorkerApiServer {
     platform_property_manager: Arc<PlatformPropertyManager>,
     scheduler: Arc<Scheduler>,
+    now_fn: NowFn,
 }
 
 impl WorkerApiServer {
     pub fn new(platform_property_manager: Arc<PlatformPropertyManager>, scheduler: Arc<Scheduler>) -> Self {
+        Self::new_with_now_fn(
+            platform_property_manager,
+            scheduler,
+            Box::new(move || {
+                SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .map_err(|_| make_err!(Code::Internal, "System time is now behind unix epoch"))
+            }),
+        )
+    }
+
+    /// Same as new(), but you can pass a custom `now_fn`, that returns a Duration since UNIX_EPOCH
+    /// representing the current time. Used mostly in unit tests.
+    pub fn new_with_now_fn(
+        platform_property_manager: Arc<PlatformPropertyManager>,
+        scheduler: Arc<Scheduler>,
+        now_fn: NowFn,
+    ) -> Self {
         Self {
             platform_property_manager,
             scheduler,
+            now_fn,
         }
     }
 
@@ -57,7 +80,7 @@ impl WorkerApiServer {
         // Now register the worker with the scheduler.
let worker_id = { let worker_id = Uuid::new_v4().as_u128(); - let worker = Worker::new(WorkerId(worker_id), platform_properties, tx); + let worker = Worker::new(WorkerId(worker_id), platform_properties, tx, (self.now_fn)()?.as_secs()); self.scheduler .add_worker(worker) .await @@ -81,6 +104,15 @@ impl WorkerApiServer { }, )))) } + + async fn inner_keep_alive(&self, keep_alive_request: KeepAliveRequest) -> Result, Error> { + let worker_id: WorkerId = keep_alive_request.worker_id.try_into()?; + self.scheduler + .worker_keep_alive_received(&worker_id, (self.now_fn)()?.as_secs()) + .await + .err_tip(|| "Could not process keep_alive from worker in inner_keep_alive()")?; + Ok(Response::new(())) + } } #[tonic::async_trait] @@ -103,11 +135,21 @@ impl WorkerApi for WorkerApiServer { return resp.map_err(|e| e.into()); } - async fn keep_alive(&self, _grpc_request: Request<()>) -> Result, Status> { - unimplemented!(); + async fn keep_alive(&self, grpc_request: Request) -> Result, Status> { + let now = Instant::now(); + log::info!("\x1b[0;31mkeep_alive Req\x1b[0m: {:?}", grpc_request.get_ref()); + let keep_alive_request = grpc_request.into_inner(); + let resp = self.inner_keep_alive(keep_alive_request).await; + let d = now.elapsed().as_secs_f32(); + if let Err(err) = resp.as_ref() { + log::error!("\x1b[0;31mkeep_alive Resp\x1b[0m: {} {:?}", d, err); + } else { + log::info!("\x1b[0;31mkeep_alive Resp\x1b[0m: {}", d); + } + return resp.map_err(|e| e.into()); } - async fn going_away(&self, _grpc_request: Request<()>) -> Result, Status> { + async fn going_away(&self, _grpc_request: Request) -> Result, Status> { unimplemented!(); } diff --git a/cas/scheduler/BUILD b/cas/scheduler/BUILD index 20cadcf7f..cd1d6ec2e 100644 --- a/cas/scheduler/BUILD +++ b/cas/scheduler/BUILD @@ -17,6 +17,7 @@ rust_library( srcs = ["scheduler.rs"], deps = [ "//third_party:fast_async_mutex", + "//third_party:lru", "//third_party:rand", "//third_party:tokio", "//util:common", diff --git a/cas/scheduler/scheduler.rs b/cas/scheduler/scheduler.rs index e10bf7b85..ab521ea2e 100644 --- a/cas/scheduler/scheduler.rs +++ b/cas/scheduler/scheduler.rs @@ -5,13 +5,14 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use fast_async_mutex::mutex::Mutex; +use lru::LruCache; use rand::{thread_rng, Rng}; use tokio::sync::watch; use action_messages::{ActionInfo, ActionStage, ActionState}; use common::log; -use error::{Error, ResultExt}; -use worker::{Worker, WorkerId, WorkerUpdate}; +use error::{error_if, make_input_err, Error, ResultExt}; +use worker::{Worker, WorkerId, WorkerTimestamp, WorkerUpdate}; /// An action that is being awaited on and last known state. struct AwaitedAction { @@ -27,33 +28,48 @@ struct RunningAction { } struct Workers { - workers: HashMap, + workers: LruCache, } impl Workers { fn new() -> Self { Self { - workers: HashMap::new(), + workers: LruCache::unbounded(), } } + /// Refreshes the lifetime of the worker with the given timestamp. + fn refresh_lifetime(&mut self, worker_id: &WorkerId, timestamp: WorkerTimestamp) -> Result<(), Error> { + let worker = self + .workers + .get_mut(worker_id) + .ok_or_else(|| make_input_err!("Worker not found in worker map in refresh_lifetime() {}", worker_id))?; + error_if!( + worker.last_update_timestamp > timestamp, + "Worker already had a timestamp of {}, but tried to update it with {}", + worker.last_update_timestamp, + timestamp + ); + worker.last_update_timestamp = timestamp; + Ok(()) + } + /// Adds a worker to the pool. 
     /// Note: This function will not do any task matching.
     fn add_worker(&mut self, worker: Worker) -> Result<(), Error> {
         let worker_id = worker.id;
-        self.workers.insert(worker_id, worker);
+        self.workers.put(worker_id, worker);
 
         // Worker is not cloneable, and we do not want to send the initial connection results until
         // we have added it to the map, or we might get some strange race conditions due to the way
         // the multi-threaded runtime works.
-        let worker = self.workers.get_mut(&worker_id).unwrap();
+        let worker = self.workers.peek_mut(&worker_id).unwrap();
         let res = worker
             .send_initial_connection_result()
             .err_tip(|| "Failed to send initial connection result to worker");
         if let Err(e) = &res {
-            self.remove_worker(&worker_id);
             log::error!(
-                "Worker connection appears to have been closed while adding to pool. Removing from queue : {:?}",
+                "Worker connection appears to have been closed while adding to pool : {:?}",
                 e
             );
         }
@@ -63,8 +79,8 @@ impl Workers {
     /// Removes worker from pool.
     /// Note: The caller is responsible for any rescheduling of any tasks that might be
     /// running.
-    fn remove_worker(&mut self, worker_id: &WorkerId) {
-        self.workers.remove(worker_id);
+    fn remove_worker(&mut self, worker_id: &WorkerId) -> Option<Worker> {
+        self.workers.pop(worker_id)
     }
 
     /// Attempts to find a worker that is capable of running this action.
@@ -74,13 +90,19 @@ impl Workers {
     fn find_worker_for_action_mut<'a>(&'a mut self, awaited_action: &AwaitedAction) -> Option<&'a mut Worker> {
         assert!(matches!(awaited_action.current_state.stage, ActionStage::Queued));
         let action_properties = &awaited_action.action_info.platform_properties;
-        return self
-            .workers
-            .values_mut()
-            .find(|w| action_properties.is_satisfied_by(&w.platform_properties));
+        return self.workers.iter_mut().find_map(|(_, w)| {
+            if action_properties.is_satisfied_by(&w.platform_properties) {
+                Some(w)
+            } else {
+                None
+            }
+        });
     }
 }
 
+/// Simple helper type to help with self-documentation.
+type ShouldRunAgain = bool;
+
 struct SchedulerImpl {
     // We cannot use the special hash function we use for ActionInfo with BTreeMap because
     // we need to be able to get an exact match when we look for `ActionInfo` structs that
@@ -93,6 +115,8 @@ struct SchedulerImpl {
     queued_actions: BTreeMap<Arc<ActionInfo>, AwaitedAction>,
     workers: Workers,
     active_actions: HashMap<Arc<ActionInfo>, RunningAction>,
+    /// Timeout of how long to evict workers if no response in this given amount of time in seconds.
+    worker_timeout_s: u64,
 }
 
 impl SchedulerImpl {
@@ -164,10 +188,49 @@ impl SchedulerImpl {
         Ok(rx)
     }
 
+    /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
+    /// Note: This will not call .do_try_match().
+    fn immediate_evict_worker(&mut self, worker_id: &WorkerId) {
+        if let Some(mut worker) = self.workers.remove_worker(&worker_id) {
+            // We don't care if we fail to send message to worker, this is only a best attempt.
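+            // (The Disconnect update below tells the worker it may drop any outstanding work; if the
+            // worker was running an action, that action is put back into the queue further down.)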
+ let _ = worker.notify_update(WorkerUpdate::Disconnect); + if let Some(action_info) = worker.running_action_info { + match self.active_actions.remove(&action_info) { + Some(running_action) => { + let mut awaited_action = running_action.action; + Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Queued; + let send_result = awaited_action.notify_channel.send(awaited_action.current_state.clone()); + self.queued_actions_set.insert(action_info.clone()); + self.queued_actions.insert(action_info.clone(), awaited_action); + if send_result.is_err() { + // Don't remove this task, instead we keep them around for a bit just in case + // the client disconnected and will reconnect and ask for same job to be executed + // again. + log::warn!( + "Action {} has no more listeners during evict_worker()", + action_info.digest.str() + ); + } + } + None => { + log::error!("Worker stated it was running an action, but it was not in the active_actions : Worker: {:?}, ActionInfo: {:?}", worker.id, action_info); + } + } + } + } + } + + /// Wrapper to keep running in the event we could not complete all scheduling in one iteration. + fn do_try_match(&mut self) { + // Run do_try_match until it doesn't need to run again. + while self.inner_do_try_match() {} + } + // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map // of capabilities of each worker and then try and match the actions to the worker using // the map lookup (ie. map reduce). - fn do_try_match(&mut self) { + fn inner_do_try_match(&mut self) -> ShouldRunAgain { + let mut should_run_again = false; // TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in // the way. We need to conditionally remove items from the `queued_action`. Rust is working // to add `drain_filter`, which would in theory solve this problem, but because we need @@ -191,7 +254,8 @@ impl SchedulerImpl { let notify_worker_result = worker.notify_update(WorkerUpdate::RunAction(action_info.clone())); if notify_worker_result.is_err() { // Remove worker, as it is no longer receiving messages and let it try to find another worker. - self.workers.remove_worker(&worker_id); + self.immediate_evict_worker(&worker_id); + should_run_again = true; continue; } @@ -218,6 +282,7 @@ impl SchedulerImpl { worker_received_msg = true; } } + should_run_again } } @@ -225,25 +290,34 @@ impl SchedulerImpl { /// the worker nodes. All state on how the workers and actions are interacting /// should be held in this struct. pub struct Scheduler { - inner: Arc>, + inner: Mutex, } impl Scheduler { - pub fn new() -> Self { + pub fn new(worker_timeout_s: u64) -> Self { Self { - inner: Arc::new(Mutex::new(SchedulerImpl { + inner: Mutex::new(SchedulerImpl { queued_actions_set: HashSet::new(), queued_actions: BTreeMap::new(), workers: Workers::new(), active_actions: HashMap::new(), - })), + worker_timeout_s, + }), } } /// Adds a worker to the scheduler and begin using it to execute actions (when able). pub async fn add_worker(&self, worker: Worker) -> Result<(), Error> { + let worker_id = worker.id.clone(); let mut inner = self.inner.lock().await; - inner.workers.add_worker(worker)?; + let res = inner + .workers + .add_worker(worker) + .err_tip(|| "Error while adding worker, removing from pool"); + if res.is_err() { + inner.immediate_evict_worker(&worker_id); + return res; + } inner.do_try_match(); Ok(()) } @@ -257,6 +331,55 @@ impl Scheduler { /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests. 
pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool { let inner = self.inner.lock().await; - inner.workers.workers.contains_key(worker_id) + inner.workers.workers.contains(worker_id) + } + + /// Event for when the keep alive message was received from the worker. + pub async fn worker_keep_alive_received( + &self, + worker_id: &WorkerId, + timestamp: WorkerTimestamp, + ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner + .workers + .refresh_lifetime(worker_id, timestamp) + .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()") + } + + pub async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire + // map most of the time. + let worker_ids_to_remove: Vec = inner + .workers + .workers + .iter() + .rev() + .map_while(|(worker_id, worker)| { + if worker.last_update_timestamp <= now_timestamp - inner.worker_timeout_s { + Some(*worker_id) + } else { + None + } + }) + .collect(); + for worker_id in worker_ids_to_remove { + log::warn!("Worker {} timed out, removing from pool", worker_id); + inner.immediate_evict_worker(&worker_id); + } + inner.do_try_match(); + Ok(()) + } + + /// A unit test function used to send the keep alive message to the worker from the server. + pub async fn send_keep_alive_to_worker_for_test(&self, worker_id: &WorkerId) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + let worker = inner + .workers + .workers + .get_mut(worker_id) + .ok_or_else(|| make_input_err!("WorkerId '{}' does not exist in workers map", worker_id))?; + worker.keep_alive() } } diff --git a/cas/scheduler/tests/scheduler_test.rs b/cas/scheduler/tests/scheduler_test.rs index e0a6b380f..98b922423 100644 --- a/cas/scheduler/tests/scheduler_test.rs +++ b/cas/scheduler/tests/scheduler_test.rs @@ -47,21 +47,24 @@ async fn verify_initial_connection_message(worker_id: WorkerId, rx: &mut mpsc::U } #[cfg(test)] -mod buf_channel_tests { +mod scheduler_tests { use super::*; use pretty_assertions::assert_eq; // Must be declared in every module. 
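+    // The scheduler evicts a worker once `last_update_timestamp <= now - worker_timeout_s`; keep-alives
+    // refresh that timestamp (and the worker's position in the scheduler's LRU map), which is what the
+    // timeout-related tests below rely on.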
+ const NOW_TIME: u64 = 10000; + const BASE_WORKER_TIMEOUT_S: u64 = 100; + #[tokio::test] async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { const WORKER_ID: WorkerId = WorkerId(0x123456789111); - let scheduler = Scheduler::new(); + let scheduler = Scheduler::new(BASE_WORKER_TIMEOUT_S); let action_digest = DigestInfo::new([99u8; 32], 512); let mut rx_from_worker = { let (tx, rx) = mpsc::unbounded_channel(); scheduler - .add_worker(Worker::new(WORKER_ID, PlatformProperties::default(), tx)) + .add_worker(Worker::new(WORKER_ID, PlatformProperties::default(), tx, NOW_TIME)) .await .err_tip(|| "Failed to add worker")?; rx @@ -106,7 +109,7 @@ mod buf_channel_tests { #[tokio::test] async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), Error> { - let scheduler = Scheduler::new(); + let scheduler = Scheduler::new(BASE_WORKER_TIMEOUT_S); let action_digest = DigestInfo::new([99u8; 32], 512); let mut platform_properties = PlatformProperties::default(); platform_properties @@ -122,7 +125,7 @@ mod buf_channel_tests { let mut rx_from_worker1 = { let (tx, rx) = mpsc::unbounded_channel(); scheduler - .add_worker(Worker::new(WORKER_ID1, platform_properties.clone(), tx)) + .add_worker(Worker::new(WORKER_ID1, platform_properties.clone(), tx, NOW_TIME)) .await .err_tip(|| "Failed to add worker")?; rx @@ -151,7 +154,7 @@ mod buf_channel_tests { let mut rx_from_worker2 = { let (tx, rx) = mpsc::unbounded_channel(); scheduler - .add_worker(Worker::new(WORKER_ID2, worker_properties, tx)) + .add_worker(Worker::new(WORKER_ID2, worker_properties, tx, NOW_TIME)) .await .err_tip(|| "Failed to add worker")?; rx @@ -194,7 +197,7 @@ mod buf_channel_tests { async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { const WORKER_ID: WorkerId = WorkerId(0x100009); - let scheduler = Scheduler::new(); + let scheduler = Scheduler::new(BASE_WORKER_TIMEOUT_S); let action_digest = DigestInfo::new([99u8; 32], 512); let mut expected_action_state = ActionState { @@ -228,7 +231,7 @@ mod buf_channel_tests { let mut rx_from_worker = { let (tx, rx) = mpsc::unbounded_channel(); scheduler - .add_worker(Worker::new(WORKER_ID, PlatformProperties::default(), tx)) + .add_worker(Worker::new(WORKER_ID, PlatformProperties::default(), tx, NOW_TIME)) .await .err_tip(|| "Failed to add worker")?; rx @@ -275,14 +278,14 @@ mod buf_channel_tests { #[tokio::test] async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), Error> { const WORKER_ID: WorkerId = WorkerId(0x100010); - let scheduler = Scheduler::new(); + let scheduler = Scheduler::new(BASE_WORKER_TIMEOUT_S); let action_digest = DigestInfo::new([99u8; 32], 512); let platform_properties = PlatformProperties::default(); let rx_from_worker = { let (tx, rx) = mpsc::unbounded_channel(); scheduler - .add_worker(Worker::new(WORKER_ID, platform_properties.clone(), tx)) + .add_worker(Worker::new(WORKER_ID, platform_properties.clone(), tx, NOW_TIME)) .await .err_tip(|| "Failed to add worker")?; rx @@ -310,4 +313,106 @@ mod buf_channel_tests { Ok(()) } + + #[tokio::test] + async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { + const WORKER_ID1: WorkerId = WorkerId(0x111111); + const WORKER_ID2: WorkerId = WorkerId(0x222222); + let scheduler = Scheduler::new(BASE_WORKER_TIMEOUT_S); + let action_digest = DigestInfo::new([99u8; 32], 512); + let platform_properties = PlatformProperties::default(); + + // Note: This needs to stay in scope or a disconnect will trigger. 
+ let mut rx_from_worker1 = { + let (tx1, rx1) = mpsc::unbounded_channel(); + scheduler + .add_worker(Worker::new(WORKER_ID1, platform_properties.clone(), tx1, NOW_TIME)) + .await + .err_tip(|| "Failed to add worker1")?; + rx1 + }; + verify_initial_connection_message(WORKER_ID1, &mut rx_from_worker1).await; + + let mut client_rx = { + let mut action_info = make_base_action_info(); + action_info.digest = action_digest.clone(); + scheduler.add_action(action_info).await? + }; + + // Note: This needs to stay in scope or a disconnect will trigger. + let mut rx_from_worker2 = { + let (tx2, rx2) = mpsc::unbounded_channel(); + scheduler + .add_worker(Worker::new(WORKER_ID2, platform_properties.clone(), tx2, NOW_TIME)) + .await + .err_tip(|| "Failed to add worker2")?; + rx2 + }; + verify_initial_connection_message(WORKER_ID2, &mut rx_from_worker2).await; + + let mut expected_action_state = ActionState { + // Name is a random string, so we ignore it and just make it the same. + name: "UNKNOWN_HERE".to_string(), + action_digest: action_digest.clone(), + stage: ActionStage::Executing, + }; + + let execution_request_for_worker = UpdateForWorker { + update: Some(update_for_worker::Update::StartAction(StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + skip_cache_lookup: true, + action_digest: Some(action_digest.clone().into()), + ..Default::default() + }), + })), + }; + + { + // Worker1 should now see execution request. + let msg_for_worker = rx_from_worker1.recv().await.unwrap(); + assert_eq!(msg_for_worker, execution_request_for_worker); + } + + { + // Client should get notification saying it's being executed. + let action_state = client_rx.borrow_and_update(); + // We now know the name of the action so populate it. + expected_action_state.name = action_state.name.clone(); + assert_eq!(action_state.as_ref(), &expected_action_state); + } + + // Keep worker 2 alive. + scheduler + .worker_keep_alive_received(&WORKER_ID2, NOW_TIME + BASE_WORKER_TIMEOUT_S) + .await?; + // This should remove worker 1 (the one executing our job). + scheduler + .remove_timedout_workers(NOW_TIME + BASE_WORKER_TIMEOUT_S) + .await?; + + { + // Worker1 should have received a disconnect message. + let msg_for_worker = rx_from_worker1.recv().await.unwrap(); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::Disconnect(())) + } + ); + } + { + // Client should get notification saying it's being executed. + let action_state = client_rx.borrow_and_update(); + expected_action_state.stage = ActionStage::Executing; + assert_eq!(action_state.as_ref(), &expected_action_state); + } + { + // Worker2 should now see execution request. + let msg_for_worker = rx_from_worker2.recv().await.unwrap(); + assert_eq!(msg_for_worker, execution_request_for_worker); + } + + Ok(()) + } } diff --git a/cas/scheduler/worker.rs b/cas/scheduler/worker.rs index b6d2216aa..ba5f1250a 100644 --- a/cas/scheduler/worker.rs +++ b/cas/scheduler/worker.rs @@ -12,6 +12,8 @@ use proto::com::github::allada::turbo_cache::remote_execution::{ }; use tokio::sync::mpsc::UnboundedSender; +pub type WorkerTimestamp = u64; + /// Unique id of worker. #[derive(Eq, PartialEq, Hash, Copy, Clone)] pub struct WorkerId(pub u128); @@ -50,6 +52,9 @@ impl TryFrom for WorkerId { pub enum WorkerUpdate { /// Requests that the worker begin executing this action. RunAction(Arc), + + /// Request that the worker is no longer in the pool and may discard any jobs. 
+ Disconnect, } /// Represents a connection to a worker and used as the medium to @@ -63,14 +68,29 @@ pub struct Worker { /// Channel to send commands from scheduler to worker. pub tx: UnboundedSender, + + /// The action info of the running action if worker is assigned one. + pub running_action_info: Option>, + + /// Timestamp of last time this worker had been communicated with. + // Warning: Do not update this timestamp without updating the placement of the worker in + // the LRUCache in the Workers struct. + pub last_update_timestamp: WorkerTimestamp, } impl Worker { - pub fn new(id: WorkerId, platform_properties: PlatformProperties, tx: UnboundedSender) -> Self { + pub fn new( + id: WorkerId, + platform_properties: PlatformProperties, + tx: UnboundedSender, + timestamp: WorkerTimestamp, + ) -> Self { Self { id, platform_properties, tx, + running_action_info: None, + last_update_timestamp: timestamp, } } @@ -87,9 +107,15 @@ impl Worker { pub fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> { match worker_update { WorkerUpdate::RunAction(action_info) => self.run_action(action_info), + WorkerUpdate::Disconnect => self.send_msg_to_worker(update_for_worker::Update::Disconnect(())), } } + pub fn keep_alive(&mut self) -> Result<(), Error> { + self.send_msg_to_worker(update_for_worker::Update::KeepAlive(())) + .err_tip(|| format!("Failed to send KeepAlive to worker : {}", self.id)) + } + fn send_msg_to_worker(&mut self, msg: update_for_worker::Update) -> Result<(), Error> { self.tx .send(UpdateForWorker { update: Some(msg) }) @@ -97,12 +123,16 @@ impl Worker { } fn run_action(&mut self, action_info: Arc) -> Result<(), Error> { + self.running_action_info = Some(action_info.clone()); self.reduce_platform_properties(&action_info.platform_properties); self.send_msg_to_worker(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(action_info.as_ref().into()), + execute_request: Some(self.running_action_info.as_ref().unwrap().as_ref().into()), })) } + /// Reduces the platform properties available on the worker based on the platform properties provided. + /// This is used because we allow more than 1 job to run on a worker at a time, and this is how the + /// scheduler knows if more jobs can run on a given worker. fn reduce_platform_properties(&mut self, props: &PlatformProperties) { debug_assert!(props.is_satisfied_by(&self.platform_properties)); for (property, prop_value) in &props.properties { diff --git a/config/cas_server.rs b/config/cas_server.rs index ee3793c51..8501b21a1 100644 --- a/config/cas_server.rs +++ b/config/cas_server.rs @@ -82,6 +82,12 @@ pub struct ExecutionConfig { /// This store name referenced here may be reused multiple times. /// This value must be a CAS store reference. pub cas_store: StoreRefName, + + /// Remove workers from pool once the worker has not responded in this + /// amount of time in seconds. 
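+    /// If this is set to 0 (or left unset), the scheduler falls back to its built-in default.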
+ /// Default: 5 (seconds) + #[serde(default)] + pub worker_timeout_s: u64, } #[derive(Deserialize, Debug)] diff --git a/proto/com/github/allada/turbo_cache/remote_execution/worker_api.proto b/proto/com/github/allada/turbo_cache/remote_execution/worker_api.proto index dcb049a59..9194b9d81 100644 --- a/proto/com/github/allada/turbo_cache/remote_execution/worker_api.proto +++ b/proto/com/github/allada/turbo_cache/remote_execution/worker_api.proto @@ -25,7 +25,7 @@ service WorkerApi { /// may close the connection if the worker has not sent any messages /// after some amount of time (configured in the scheduler's /// configuration). - rpc KeepAlive(google.protobuf.Empty) returns (google.protobuf.Empty); + rpc KeepAlive(KeepAliveRequest) returns (google.protobuf.Empty); /// Informs the scheduler that the service is going offline and /// should stop issuing any new actions on this worker. @@ -38,12 +38,26 @@ service WorkerApi { /// Any job that was running on this instance likely needs to be /// executed again, but up to the scheduler on how or when to handle /// this case. - rpc GoingAway(google.protobuf.Empty) returns (google.protobuf.Empty); + rpc GoingAway(GoingAwayRequest) returns (google.protobuf.Empty); /// Informs the scheduler about the result of an execution request. rpc ExecutionResponse(ExecuteResult) returns (google.protobuf.Empty); } +/// Request object for keep alive requests. +message KeepAliveRequest { + /// ID of the worker making the request. + string worker_id = 1; + reserved 2; // NextId. +} + +/// Request object for going away requests. +message GoingAwayRequest { + /// ID of the worker making the request. + string worker_id = 1; + reserved 2; // NextId. +} + /// Represents the initial request sent to the scheduler informing the /// scheduler about this worker's capabilities. message SupportedProperties { @@ -63,10 +77,11 @@ message SupportedProperties { /// Represents the result of an execution. message ExecuteResult { + string worker_id = 1; /// Result of the execution. See `build.bazel.remote.execution.v2.ExecuteResponse` /// for details. - build.bazel.remote.execution.v2.ExecuteResponse execute_response = 1; - reserved 2; // NextId. + build.bazel.remote.execution.v2.ExecuteResponse execute_response = 2; + reserved 3; // NextId. } /// Result sent back from the server when a node connects. @@ -93,8 +108,12 @@ message UpdateForWorker { /// Informs the worker about some work it should begin performing the /// requested action. StartExecute start_action = 3; + + /// Informs the worker that it has been disconnected from the pool. + /// The worker may discard any outstanding work that is being executed. + google.protobuf.Empty disconnect = 4; } - reserved 4; // NextId. + reserved 5; // NextId. } message StartExecute { diff --git a/proto/genproto/com.github.allada.turbo_cache.remote_execution.pb.rs b/proto/genproto/com.github.allada.turbo_cache.remote_execution.pb.rs index d2a93aa8c..8166ab42c 100644 --- a/proto/genproto/com.github.allada.turbo_cache.remote_execution.pb.rs +++ b/proto/genproto/com.github.allada.turbo_cache.remote_execution.pb.rs @@ -1,4 +1,18 @@ // Copyright 2020 Nathan (Blaise) Bruer. All rights reserved. +//// Request object for keep alive requests. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct KeepAliveRequest { + //// ID of the worker making the request. + #[prost(string, tag = "1")] + pub worker_id: ::prost::alloc::string::String, +} +//// Request object for going away requests. 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GoingAwayRequest { + //// ID of the worker making the request. + #[prost(string, tag = "1")] + pub worker_id: ::prost::alloc::string::String, +} //// Represents the initial request sent to the scheduler informing the //// scheduler about this worker's capabilities. #[derive(Clone, PartialEq, ::prost::Message)] @@ -21,9 +35,11 @@ pub struct SupportedProperties { //// Represents the result of an execution. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecuteResult { + #[prost(string, tag = "1")] + pub worker_id: ::prost::alloc::string::String, //// Result of the execution. See `build.bazel.remote.execution.v2.ExecuteResponse` //// for details. - #[prost(message, optional, tag = "1")] + #[prost(message, optional, tag = "2")] pub execute_response: ::core::option::Option< super::super::super::super::super::build::bazel::remote::execution::v2::ExecuteResponse, >, @@ -38,7 +54,7 @@ pub struct ConnectionResult { //// Communication from the scheduler to the worker. #[derive(Clone, PartialEq, ::prost::Message)] pub struct UpdateForWorker { - #[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3")] + #[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4")] pub update: ::core::option::Option, } /// Nested message and enum types in `UpdateForWorker`. @@ -60,6 +76,10 @@ pub mod update_for_worker { //// requested action. #[prost(message, tag = "3")] StartAction(super::StartExecute), + //// Informs the worker that it has been disconnected from the pool. + //// The worker may discard any outstanding work that is being executed. + #[prost(message, tag = "4")] + Disconnect(()), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -165,7 +185,7 @@ pub mod worker_api_client { #[doc = "/ configuration)."] pub async fn keep_alive( &mut self, - request: impl tonic::IntoRequest<()>, + request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { self.inner.ready().await.map_err(|e| { tonic::Status::new( @@ -192,7 +212,7 @@ pub mod worker_api_client { #[doc = "/ this case."] pub async fn going_away( &mut self, - request: impl tonic::IntoRequest<()>, + request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { self.inner.ready().await.map_err(|e| { tonic::Status::new( @@ -250,7 +270,7 @@ pub mod worker_api_server { #[doc = "/ configuration)."] async fn keep_alive( &self, - request: tonic::Request<()>, + request: tonic::Request, ) -> Result, tonic::Status>; #[doc = "/ Informs the scheduler that the service is going offline and"] #[doc = "/ should stop issuing any new actions on this worker."] @@ -265,7 +285,7 @@ pub mod worker_api_server { #[doc = "/ this case."] async fn going_away( &self, - request: tonic::Request<()>, + request: tonic::Request, ) -> Result, tonic::Status>; #[doc = "/ Informs the scheduler about the result of an execution request."] async fn execution_response( @@ -358,10 +378,13 @@ pub mod worker_api_server { "/com.github.allada.turbo_cache.remote_execution.WorkerApi/KeepAlive" => { #[allow(non_camel_case_types)] struct KeepAliveSvc(pub Arc); - impl tonic::server::UnaryService<()> for KeepAliveSvc { + impl tonic::server::UnaryService for KeepAliveSvc { type Response = (); type Future = BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request<()>) -> Self::Future { + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { let inner = self.0.clone(); let fut = async move { (*inner).keep_alive(request).await }; Box::pin(fut) @@ -386,10 +409,13 @@ pub mod worker_api_server { 
"/com.github.allada.turbo_cache.remote_execution.WorkerApi/GoingAway" => { #[allow(non_camel_case_types)] struct GoingAwaySvc(pub Arc); - impl tonic::server::UnaryService<()> for GoingAwaySvc { + impl tonic::server::UnaryService for GoingAwaySvc { type Response = (); type Future = BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request<()>) -> Self::Future { + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { let inner = self.0.clone(); let fut = async move { (*inner).going_away(request).await }; Box::pin(fut)