Skip to content

Commit

Permalink
All tokio::spawn and related functions must use nativelink's version (#…
Browse files Browse the repository at this point in the history
…890)

We now enforce that all locations in our code base use one of the
`nativelink-util::task` functions when performing a task that might
require an operation that changes threads.
  • Loading branch information
allada authored Apr 26, 2024
1 parent c0d7eaa commit c1d0402
Show file tree
Hide file tree
Showing 31 changed files with 465 additions and 324 deletions.
3 changes: 2 additions & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ build --aspects=@rules_rust//rust:defs.bzl%rustfmt_aspect
build --aspects=@rules_rust//rust:defs.bzl%rust_clippy_aspect

# TODO(aaronmondal): Extend these flags until we can run with clippy::pedantic.
build --@rules_rust//:clippy_flags=-D,clippy::uninlined_format_args
build --@rules_rust//:clippy_flags=-Dwarnings,-Dclippy::uninlined_format_args
build --@rules_rust//:clippy.toml=//:clippy.toml

test --@rules_rust//:rustfmt.toml=//:.rustfmt.toml

Expand Down
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ load("@rules_rust//rust:defs.bzl", "rust_binary")
exports_files(
[
".rustfmt.toml",
"clippy.toml",
],
visibility = ["//visibility:public"],
)
Expand Down
16 changes: 16 additions & 0 deletions clippy.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Configuration for Clippy's `disallowed_methods` lint.
#
# Each entry names a function/method path that Clippy should flag, together
# with the `reason` text shown to the developer. Every reason directs the
# caller to a replacement in `nativelink-util::task`.
disallowed-methods = [
# Tokio task-spawning and blocking entry points.
{ path = "tokio::spawn", reason = "use `nativelink-util::task::spawn` or `nativelink-util::task::background_spawn` instead" },
{ path = "tokio::task::spawn", reason = "use `nativelink-util::task::spawn` or `nativelink-util::task::background_spawn` instead" },
{ path = "tokio::task::spawn_blocking", reason = "use `nativelink-util::task::spawn_blocking` instead" },
{ path = "tokio::task::block_in_place", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::task::spawn_local", reason = "use one of the `nativelink-util::task` functions instead" },
# Direct construction or use of a Tokio runtime.
{ path = "tokio::runtime::Builder::new_current_thread", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Builder::new_multi_thread", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Builder::new_multi_thread_alt", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Runtime::new", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Runtime::spawn", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Runtime::spawn_blocking", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "tokio::runtime::Runtime::block_on", reason = "use one of the `nativelink-util::task` functions instead" },
# Raw OS threads from the standard library.
{ path = "std::thread::spawn", reason = "use one of the `nativelink-util::task` functions instead" },
{ path = "std::thread::Builder::new", reason = "use one of the `nativelink-util::task` functions instead" },
]
6 changes: 5 additions & 1 deletion nativelink-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream {

let expanded = quote! {
#(#fn_attr)*
#[allow(clippy::disallowed_methods)]
#[tokio::test(#attr)]
async fn #fn_name(#fn_inputs) #fn_output {
#fn_block
#[warn(clippy::disallowed_methods)]
{
#fn_block
}
}
};

Expand Down
3 changes: 2 additions & 1 deletion nativelink-scheduler/src/cache_lookup_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use nativelink_store::grpc_store::GrpcStore;
use nativelink_util::action_messages::{
ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState,
};
use nativelink_util::background_spawn;
use nativelink_util::common::DigestInfo;
use nativelink_util::store_trait::Store;
use parking_lot::{Mutex, MutexGuard};
Expand Down Expand Up @@ -158,7 +159,7 @@ impl ActionScheduler for CacheLookupScheduler {
let ac_store = self.ac_store.clone();
let action_scheduler = self.action_scheduler.clone();
// We need this spawn because we are returning a stream and this spawn will populate the stream's data.
tokio::spawn(async move {
background_spawn!("cache_lookup_scheduler_add_action", async move {
// If our spawn ever dies, we will remove the action from the cache_check_actions map.
let _scope_guard = scope_guard;

Expand Down
3 changes: 2 additions & 1 deletion nativelink-scheduler/src/default_scheduler_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::time::Duration;
use nativelink_config::schedulers::SchedulerConfig;
use nativelink_error::{Error, ResultExt};
use nativelink_store::store_manager::StoreManager;
use nativelink_util::background_spawn;
use nativelink_util::metrics_utils::Registry;
use tokio::time::interval;

Expand Down Expand Up @@ -117,7 +118,7 @@ fn inner_scheduler_factory(

fn start_cleanup_timer(action_scheduler: &Arc<dyn ActionScheduler>) {
let weak_scheduler = Arc::downgrade(action_scheduler);
tokio::spawn(async move {
background_spawn!("default_scheduler_factory_cleanup_timer", async move {
let mut ticker = interval(Duration::from_secs(1));
loop {
ticker.tick().await;
Expand Down
9 changes: 4 additions & 5 deletions nativelink-scheduler/src/grpc_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ use nativelink_util::action_messages::{
};
use nativelink_util::connection_manager::ConnectionManager;
use nativelink_util::retry::{Retrier, RetryResult};
use nativelink_util::tls_utils;
use nativelink_util::{background_spawn, tls_utils};
use parking_lot::Mutex;
use rand::rngs::OsRng;
use rand::Rng;
use tokio::select;
use tokio::sync::watch;
use tokio::time::sleep;
use tonic::{Request, Streaming};
use tracing::{error_span, event, Instrument, Level};
use tracing::{event, Level};

use crate::action_scheduler::ActionScheduler;
use crate::platform_property_manager::PlatformPropertyManager;
Expand Down Expand Up @@ -119,7 +119,7 @@ impl GrpcScheduler {
.err_tip(|| "Recieving response from upstream scheduler")?
{
let (tx, rx) = watch::channel(Arc::new(initial_response.try_into()?));
tokio::spawn(async move {
background_spawn!("grpc_scheduler_stream_state", async move {
loop {
select!(
_ = tx.closed() => {
Expand Down Expand Up @@ -157,8 +157,7 @@ impl GrpcScheduler {
}
)
}
}
.instrument(error_span!("stream_state")));
});
return Ok(rx);
}
Err(make_err!(
Expand Down
59 changes: 29 additions & 30 deletions nativelink-scheduler/src/simple_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ use nativelink_util::metrics_utils::{
MetricsComponent, Registry,
};
use nativelink_util::platform_properties::PlatformPropertyValue;
use nativelink_util::spawn;
use nativelink_util::task::JoinHandleDropGuard;
use parking_lot::{Mutex, MutexGuard};
use tokio::sync::{watch, Notify};
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tracing::{event, Level};

Expand Down Expand Up @@ -689,8 +690,9 @@ impl SimpleSchedulerImpl {
pub struct SimpleScheduler {
inner: Arc<Mutex<SimpleSchedulerImpl>>,
platform_property_manager: Arc<PlatformPropertyManager>,
task_worker_matching_future: JoinHandle<()>,
metrics: Arc<Metrics>,
// Triggers `drop()` call if scheduler is dropped.
_task_worker_matching_future: JoinHandleDropGuard<()>,
}

impl SimpleScheduler {
Expand Down Expand Up @@ -758,29 +760,32 @@ impl SimpleScheduler {
Self {
inner,
platform_property_manager,
task_worker_matching_future: tokio::spawn(async move {
// Break out of the loop only when the inner is dropped.
loop {
tasks_or_workers_change_notify.notified().await;
match weak_inner.upgrade() {
// Note: According to `parking_lot` documentation, the default
// `Mutex` implementation is eventual fairness, so we don't
// really need to worry about this thread taking the lock
// starving other threads too much.
Some(inner_mux) => {
let mut inner = inner_mux.lock();
let timer = metrics_for_do_try_match.do_try_match.begin_timer();
inner.do_try_match();
timer.measure();
}
// If the inner went away it means the scheduler is shutting
// down, so we need to resolve our future.
None => return,
};
on_matching_engine_run().await;
_task_worker_matching_future: spawn!(
"simple_scheduler_task_worker_matching",
async move {
// Break out of the loop only when the inner is dropped.
loop {
tasks_or_workers_change_notify.notified().await;
match weak_inner.upgrade() {
// Note: According to `parking_lot` documentation, the default
// `Mutex` implementation is eventual fairness, so we don't
// really need to worry about this thread taking the lock
// starving other threads too much.
Some(inner_mux) => {
let mut inner = inner_mux.lock();
let timer = metrics_for_do_try_match.do_try_match.begin_timer();
inner.do_try_match();
timer.measure();
}
// If the inner went away it means the scheduler is shutting
// down, so we need to resolve our future.
None => return,
};
on_matching_engine_run().await;
}
// Unreachable.
}
// Unreachable.
}),
),
metrics,
}
}
Expand Down Expand Up @@ -982,12 +987,6 @@ impl WorkerScheduler for SimpleScheduler {
}
}

impl Drop for SimpleScheduler {
fn drop(&mut self) {
self.task_worker_matching_future.abort();
}
}

impl MetricsComponent for SimpleScheduler {
fn gather_metrics(&self, c: &mut CollectorState) {
self.metrics.gather_metrics(c);
Expand Down
11 changes: 5 additions & 6 deletions nativelink-service/src/bytestream_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ use nativelink_util::buf_channel::{
use nativelink_util::common::DigestInfo;
use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper;
use nativelink_util::resource_info::ResourceInfo;
use nativelink_util::spawn;
use nativelink_util::store_trait::{Store, UploadSizeInfo};
use nativelink_util::task::JoinHandleDropGuard;
use parking_lot::Mutex;
use tokio::task::AbortHandle;
use tokio::time::sleep;
use tonic::{Request, Response, Status, Streaming};
use tracing::{enabled, error_span, event, instrument, Instrument, Level};
Expand Down Expand Up @@ -110,15 +111,14 @@ impl<'a> Drop for ActiveStreamGuard<'a> {
let sleep_fn = self.bytestream_server.sleep_fn.clone();
active_uploads_slot.1 = Some(IdleStream {
stream_state,
abort_timeout_handle: tokio::spawn(async move {
_timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move {
(*sleep_fn)().await;
if let Some(active_uploads) = weak_active_uploads.upgrade() {
let mut active_uploads = active_uploads.lock();
event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid);
active_uploads.remove(&uuid);
}
})
.abort_handle(),
}),
});
}
}
Expand All @@ -129,7 +129,7 @@ impl<'a> Drop for ActiveStreamGuard<'a> {
#[derive(Debug)]
struct IdleStream {
stream_state: StreamState,
abort_timeout_handle: AbortHandle,
_timeout_streaam_drop_guard: JoinHandleDropGuard<()>,
}

impl IdleStream {
Expand All @@ -138,7 +138,6 @@ impl IdleStream {
bytes_received: Arc<AtomicU64>,
bytestream_server: &ByteStreamServer,
) -> ActiveStreamGuard<'_> {
self.abort_timeout_handle.abort();
ActiveStreamGuard {
stream_state: Some(self.stream_state),
bytes_received,
Expand Down
57 changes: 31 additions & 26 deletions nativelink-service/src/health_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use hyper::{Body, Request, Response, StatusCode};
use nativelink_util::health_utils::{
HealthRegistry, HealthStatus, HealthStatusDescription, HealthStatusReporter,
};
use nativelink_util::task::instrument_future;
use tower::Service;
use tracing::error_span;

/// Content type header value for JSON.
const JSON_CONTENT_TYPE: &str = "application/json; charset=utf-8";
Expand All @@ -49,34 +51,37 @@ impl Service<Request<hyper::Body>> for HealthServer {

fn call(&mut self, _req: Request<Body>) -> Self::Future {
let health_registry = self.health_registry.clone();
Box::pin(async move {
let health_status_descriptions: Vec<HealthStatusDescription> =
health_registry.health_status_report().collect().await;
match serde_json5::to_string(&health_status_descriptions) {
Ok(body) => {
let contains_failed_report =
health_status_descriptions.iter().any(|description| {
matches!(description.status, HealthStatus::Failed { .. })
});
let status_code = if contains_failed_report {
StatusCode::SERVICE_UNAVAILABLE
} else {
StatusCode::OK
};
Box::pin(instrument_future(
async move {
let health_status_descriptions: Vec<HealthStatusDescription> =
health_registry.health_status_report().collect().await;
match serde_json5::to_string(&health_status_descriptions) {
Ok(body) => {
let contains_failed_report =
health_status_descriptions.iter().any(|description| {
matches!(description.status, HealthStatus::Failed { .. })
});
let status_code = if contains_failed_report {
StatusCode::SERVICE_UNAVAILABLE
} else {
StatusCode::OK
};

Ok(Response::builder()
.status(status_code)
Ok(Response::builder()
.status(status_code)
.header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE))
.body(Body::from(body))
.unwrap())
}

Err(e) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE))
.body(Body::from(body))
.unwrap())
.body(Body::from(format!("Internal Failure: {e:?}")))
.unwrap()),
}

Err(e) => Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE))
.body(Body::from(format!("Internal Failure: {e:?}")))
.unwrap()),
}
})
},
error_span!("health_server_call"),
))
}
}
13 changes: 5 additions & 8 deletions nativelink-service/src/worker_api_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
};
use nativelink_scheduler::worker::{Worker, WorkerId};
use nativelink_scheduler::worker_scheduler::WorkerScheduler;
use nativelink_util::background_spawn;
use nativelink_util::action_messages::ActionInfoHashKey;
use nativelink_util::common::DigestInfo;
use nativelink_util::platform_properties::PlatformProperties;
use tokio::sync::mpsc;
use tokio::time::interval;
use tonic::{Request, Response, Status};
use tracing::{error_span, event, instrument, Instrument, Level};
use tracing::{event, instrument, Level};
use uuid::Uuid;

pub type ConnectWorkerStream =
Expand All @@ -58,7 +59,7 @@ impl WorkerApiServer {
// event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will
// eventually see the Arc went away and return.
let weak_scheduler = Arc::downgrade(scheduler);
tokio::spawn(async move {
background_spawn!("worker_api_server", async move {
let mut ticker = interval(Duration::from_secs(1));
loop {
ticker.tick().await;
Expand All @@ -70,18 +71,14 @@ impl WorkerApiServer {
if let Err(err) =
scheduler.remove_timedout_workers(timestamp.as_secs()).await
{
event!(
Level::ERROR,
?err,
"Failed to remove_timedout_workers",
);
event!(Level::ERROR, ?err, "Failed to remove_timedout_workers",);
}
}
// If we fail to upgrade, our service is probably destroyed, so return.
None => return,
}
}
}.instrument(error_span!("worker_api_server")));
});
}

Self::new_with_now_fn(
Expand Down
Loading

0 comments on commit c1d0402

Please sign in to comment.