Add wait_drained to SchedulerServer and Executor (#41)
mpurins-coralogix authored Mar 22, 2023
1 parent ae006d7 commit 0637711
Showing 7 changed files with 89 additions and 44 deletions.
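The heart of this change is a small drain-notification pattern built on tokio's `watch` channel: the `Executor` and the scheduler's `TaskManager` each hold a `watch::Sender<()>` that is pinged whenever a task or job is removed, and the new `wait_drained` methods loop until the active count reaches zero, re-checking after every notification. Below is a minimal, self-contained sketch of that pattern; the `Pool` type and its fields are illustrative stand-ins, not Ballista's actual structs.

```rust
use std::sync::Arc;

use dashmap::DashMap;
use tokio::sync::watch;

/// Hypothetical stand-in for Executor/TaskManager: it tracks active work
/// and signals a watch channel whenever the last item is removed.
struct Pool {
    active: Arc<DashMap<usize, ()>>,
    drained: Arc<watch::Sender<()>>,
    check_drained: watch::Receiver<()>,
}

impl Pool {
    fn new() -> Self {
        let (drained, check_drained) = watch::channel(());
        Self {
            active: Arc::new(DashMap::new()),
            drained: Arc::new(drained),
            check_drained,
        }
    }

    fn start(&self, id: usize) {
        self.active.insert(id, ());
    }

    fn finish(&self, id: usize) {
        self.active.remove(&id);
        if self.active.is_empty() {
            // Notify waiters; send_replace succeeds even if no receiver is waiting.
            self.drained.send_replace(());
        }
    }

    /// Resolves once no active work remains.
    async fn wait_drained(&self) {
        let mut check_drained = self.check_drained.clone();
        loop {
            if self.active.is_empty() {
                break;
            }
            // Wake whenever `finish` signals; Err means the sender was dropped.
            if check_drained.changed().await.is_err() {
                break;
            }
        }
    }
}
```

In the commit itself, the `finish` step corresponds to the executor's `remove_handle` and the task manager's job-removal path, and the executor shutdown sequence awaits `wait_drained` in place of the old hand-rolled `TasksDrainedFuture`.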
1 change: 0 additions & 1 deletion ballista/core/src/execution_plans/shuffle_reader.rs
@@ -649,7 +649,6 @@ mod tests {

fn get_test_partition_locations(n: usize, path: String) -> Vec<PartitionLocation> {
(0..n)
.into_iter()
.map(|partition_id| PartitionLocation {
map_partition_id: 0,
partition_id: PartitionId {
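Several of the one-line deletions in this commit (here, in `test_utils.rs`, and in the cluster event tests) drop a redundant `.into_iter()` on a range: `Range` already implements `Iterator`, so adapters like `.map(...)` can be called on it directly, and clippy reports the extra call as a useless conversion. A trivial illustration:

```rust
fn main() {
    // A Range is already an Iterator, so no `.into_iter()` adapter is needed.
    let ids: Vec<usize> = (0..4).map(|partition_id| partition_id * 10).collect();
    assert_eq!(ids, vec![0, 10, 20, 30]);
}
```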
74 changes: 46 additions & 28 deletions ballista/executor/src/executor.rs
@@ -19,10 +19,7 @@

use dashmap::DashMap;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::metrics::ExecutorMetricsCollector;
use ballista_core::error::BallistaError;
@@ -37,23 +34,10 @@ use datafusion::physical_plan::udaf::AggregateUDF;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use futures::future::AbortHandle;
use tokio::sync::watch;

use ballista_core::serde::scheduler::PartitionId;

pub struct TasksDrainedFuture(pub Arc<Executor>);

impl Future for TasksDrainedFuture {
type Output = ();

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.0.abort_handles.len() > 0 {
Poll::Pending
} else {
Poll::Ready(())
}
}
}

type AbortHandles = Arc<DashMap<(usize, PartitionId), AbortHandle>>;

/// Ballista executor
@@ -82,6 +66,9 @@ pub struct Executor {

/// Handles to abort executing tasks
abort_handles: AbortHandles,

drained: Arc<watch::Sender<()>>,
check_drained: watch::Receiver<()>,
}

impl Executor {
@@ -93,17 +80,15 @@ impl Executor {
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
) -> Self {
Self {
Self::with_functions(
metadata,
work_dir: work_dir.to_owned(),
// TODO add logic to dynamically load UDF/UDAFs libs from files
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
work_dir,
runtime,
metrics_collector,
concurrent_tasks,
abort_handles: Default::default(),
}
HashMap::new(),
HashMap::new(),
)
}

pub fn with_functions(
@@ -115,6 +100,8 @@ impl Executor {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
) -> Self {
let (drained, check_drained) = watch::channel(());

Self {
metadata,
work_dir: work_dir.to_owned(),
@@ -124,6 +111,8 @@ impl Executor {
metrics_collector,
concurrent_tasks,
abort_handles: Default::default(),
drained: Arc::new(drained),
check_drained,
}
}
}
@@ -147,9 +136,11 @@ impl Executor {
self.abort_handles
.insert((task_id, partition.clone()), abort_handle);

let partitions = task.await??;
let partitions = task.await;

self.remove_handle(task_id, partition.clone());

self.abort_handles.remove(&(task_id, partition.clone()));
let partitions = partitions??;

self.metrics_collector.record_stage(
&partition.job_id,
@@ -196,14 +187,14 @@ impl Executor {
stage_id: usize,
partition_id: usize,
) -> Result<bool, BallistaError> {
if let Some((_, handle)) = self.abort_handles.remove(&(
if let Some((_, handle)) = self.remove_handle(
task_id,
PartitionId {
job_id,
stage_id,
partition_id,
},
)) {
) {
handle.abort();
Ok(true)
} else {
@@ -218,6 +209,33 @@ impl Executor {
pub fn active_task_count(&self) -> usize {
self.abort_handles.len()
}

pub async fn wait_drained(&self) {
let mut check_drained = self.check_drained.clone();
loop {
if self.active_task_count() == 0 {
break;
}

if check_drained.changed().await.is_err() {
break;
};
}
}

fn remove_handle(
&self,
task_id: usize,
partition: PartitionId,
) -> Option<((usize, PartitionId), AbortHandle)> {
let removed = self.abort_handles.remove(&(task_id, partition));

if self.active_task_count() == 0 {
self.drained.send_replace(());
}

removed
}
}

#[cfg(test)]
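One subtle ordering change in `executor.rs` above: the task's result is now captured into a local before the abort handle is removed, and the `??` is applied only afterwards. With the previous ordering, a failing task would propagate its error via `?` before the handle was cleaned up, leaving a stale entry in `abort_handles` that would also keep the new `wait_drained` pending forever. A generic sketch of the idiom, with illustrative names:

```rust
/// Run a fallible task, always performing cleanup before propagating its error.
async fn run_and_cleanup<T, E>(
    task: impl std::future::Future<Output = Result<T, E>>,
    cleanup: impl FnOnce(),
) -> Result<T, E> {
    // Capture the outcome first so cleanup runs on success and failure alike...
    let result = task.await;
    cleanup();
    // ...and only then hand any error back to the caller.
    result
}
```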
4 changes: 2 additions & 2 deletions ballista/executor/src/executor_process.rs
@@ -56,7 +56,7 @@ use ballista_core::utils::{
};
use ballista_core::BALLISTA_VERSION;

use crate::executor::{Executor, TasksDrainedFuture};
use crate::executor::Executor;
use crate::executor_server::TERMINATING;
use crate::flight_service::BallistaFlightService;
use crate::metrics::LoggingMetricsCollector;
@@ -301,7 +301,7 @@ pub async fn start_executor_process(opt: ExecutorProcessConfig) -> Result<()> {
shutdown_noti.subscribe_for_shutdown(),
)));

let tasks_drained = TasksDrainedFuture(executor);
let tasks_drained = executor.wait_drained();

// Concurrently run the service checking, listen for the `shutdown` signal, and wait for the stop request to arrive.
// The check_services runs until an error is encountered, so under normal circumstances, this `select!` statement runs
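In `executor_process.rs` the future returned by `executor.wait_drained()` replaces the old `TasksDrainedFuture` value and is awaited as part of the shutdown sequence. Exactly how it is combined with the shutdown signal is not visible in this hunk; the sketch below only illustrates the general shape of racing a drain future against a shutdown deadline, with placeholder futures standing in for the real ones:

```rust
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Placeholders: in the real process these would be the shutdown notification
    // and `executor.wait_drained()`.
    let shutdown_deadline = tokio::time::sleep(Duration::from_secs(30));
    let tasks_drained = tokio::time::sleep(Duration::from_millis(100));

    tokio::select! {
        _ = shutdown_deadline => println!("deadline hit before all tasks finished"),
        _ = tasks_drained => println!("all running tasks drained; safe to stop"),
    }
}
```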
6 changes: 3 additions & 3 deletions ballista/scheduler/src/cluster/event/mod.rs
@@ -194,7 +194,7 @@ mod test {
}
});

let expected: Vec<i32> = (0..100).into_iter().collect();
let expected: Vec<i32> = (0..100).collect();

let results = handle.await.unwrap();
assert_eq!(results.len(), 3);
@@ -233,7 +233,7 @@ mod test {

// When we reach capacity older events should be dropped so we only see
// the last 8 events in our subscribers
let expected: Vec<i32> = (92..100).into_iter().collect();
let expected: Vec<i32> = (92..100).collect();

let results = handle.await.unwrap();
assert_eq!(results.len(), 3);
@@ -271,7 +271,7 @@ mod test {
}
});

let expected: Vec<i32> = (1..=100).into_iter().collect();
let expected: Vec<i32> = (1..=100).collect();

let results = handle.await.unwrap();
assert_eq!(results.len(), 3);
4 changes: 4 additions & 0 deletions ballista/scheduler/src/scheduler_server/mod.rs
@@ -381,6 +381,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
pub fn session_manager(&self) -> SessionManager {
self.state.session_manager.clone()
}

pub async fn wait_drained(&self) {
self.state.task_manager.wait_drained().await;
}
}

pub fn timestamp_secs() -> u64 {
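`SchedulerServer::wait_drained` simply delegates to the task manager's new method, so a caller that wants a bounded graceful stop can wrap the returned future in a timeout. A hedged usage sketch; the helper name and the 30-second budget are assumptions, not anything the commit prescribes:

```rust
use std::time::Duration;
use tokio::time::timeout;

/// Wait for a drain future such as `scheduler_server.wait_drained()`,
/// but give up after a fixed budget.
async fn drain_with_deadline(wait_drained: impl std::future::Future<Output = ()>) {
    match timeout(Duration::from_secs(30), wait_drained).await {
        Ok(()) => println!("all active jobs drained"),
        Err(_) => println!("drain deadline elapsed; shutting down anyway"),
    }
}
```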
42 changes: 34 additions & 8 deletions ballista/scheduler/src/state/task_manager.rs
@@ -48,7 +48,7 @@ use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tokio::sync::{watch, RwLock};

use crate::scheduler_server::timestamp_millis;
use tracing::trace;
@@ -115,6 +115,8 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
// Cache for active jobs curated by this scheduler
active_job_cache: ActiveJobCache,
launcher: Arc<dyn TaskLauncher>,
drained: Arc<watch::Sender<()>>,
check_drained: watch::Receiver<()>,
}

#[derive(Clone)]
@@ -149,13 +151,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
codec: BallistaCodec<T, U>,
scheduler_id: String,
) -> Self {
Self {
Self::with_launcher(
state,
codec,
scheduler_id: scheduler_id.clone(),
active_job_cache: Arc::new(DashMap::new()),
launcher: Arc::new(DefaultTaskLauncher::new(scheduler_id)),
}
scheduler_id.clone(),
Arc::new(DefaultTaskLauncher::new(scheduler_id)),
)
}

#[allow(dead_code)]
@@ -165,12 +166,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
scheduler_id: String,
launcher: Arc<dyn TaskLauncher>,
) -> Self {
let (drained, check_drained) = watch::channel(());

Self {
state,
codec,
scheduler_id,
active_job_cache: Arc::new(DashMap::new()),
launcher,
drained: Arc::new(drained),
check_drained,
}
}

@@ -690,9 +695,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
&self,
job_id: &str,
) -> Option<Arc<RwLock<ExecutionGraph>>> {
self.active_job_cache
let removed = self
.active_job_cache
.remove(job_id)
.map(|value| value.1.execution_graph)
.map(|value| value.1.execution_graph);

if self.get_active_job_count() == 0 {
self.drained.send_replace(());
}

removed
}

/// Generate a new random Job ID
@@ -721,6 +733,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
}
});
}

pub async fn wait_drained(&self) {
let mut check_drained = self.check_drained.clone();

loop {
if self.get_active_job_count() == 0 {
break;
}

if check_drained.changed().await.is_err() {
break;
};
}
}
}

pub struct JobOverview {
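Both `wait_drained` loops check the active count before awaiting `changed()`, and that ordering is what makes the pattern race-free: a tokio `watch` receiver tracks a version number, so a `send_replace` that lands between the count check and the `.await` still wakes the waiter. A small self-contained demonstration of that property (plain tokio, not Ballista code):

```rust
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(());

    // Simulate a drain notification arriving while the waiter is between its
    // "is there still work?" check and the call to `changed().await`.
    tx.send_replace(());

    // `changed()` resolves immediately because the version bump was recorded,
    // so the notification sent before the await is not lost.
    assert!(rx.changed().await.is_ok());
    println!("notification observed; the waiter would now re-check the count");
}
```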
2 changes: 0 additions & 2 deletions ballista/scheduler/src/test_utils.rs
@@ -279,7 +279,6 @@ pub fn default_task_runner() -> impl TaskRunner {
};

let partitions: Vec<ShuffleWritePartition> = (0..partitions)
.into_iter()
.map(|i| ShuffleWritePartition {
partition_id: i as u64,
path: String::default(),
@@ -410,7 +409,6 @@ impl SchedulerTest {
let runner = runner.unwrap_or_else(|| Arc::new(default_task_runner()));

let executors: HashMap<String, VirtualExecutor> = (0..num_executors)
.into_iter()
.map(|i| {
let id = format!("virtual-executor-{i}");
let executor = VirtualExecutor {
