Skip to content

Commit

Permalink
Remove mutexes around boxed services (#2947)
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte authored Sep 29, 2024
1 parent 3eb8854 commit fb4b189
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 119 deletions.
3 changes: 0 additions & 3 deletions axum/clippy.toml

This file was deleted.

17 changes: 8 additions & 9 deletions axum/src/boxed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{convert::Infallible, fmt};

use crate::extract::Request;
use crate::util::AxumMutex;
use tower::Service;

use crate::{
Expand All @@ -10,7 +9,7 @@ use crate::{
Router,
};

pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
pub(crate) struct BoxedIntoRoute<S, E>(Box<dyn ErasedIntoRoute<S, E>>);

impl<S> BoxedIntoRoute<S, Infallible>
where
Expand All @@ -21,10 +20,10 @@ where
H: Handler<T, S>,
T: 'static,
{
Self(AxumMutex::new(Box::new(MakeErasedHandler {
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
})))
}))
}
}

Expand All @@ -36,20 +35,20 @@ impl<S, E> BoxedIntoRoute<S, E> {
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
E2: 'static,
{
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
inner: self.0.into_inner().unwrap(),
BoxedIntoRoute(Box::new(Map {
inner: self.0,
layer: Box::new(f),
})))
}))
}

pub(crate) fn into_route(self, state: S) -> Route<E> {
self.0.into_inner().unwrap().into_route(state)
self.0.into_route(state)
}
}

impl<S, E> Clone for BoxedIntoRoute<S, E> {
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
Self(self.0.clone_box())
}
}

Expand Down
11 changes: 5 additions & 6 deletions axum/src/routing/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
body::{Body, HttpBody},
box_clone_service::BoxCloneService,
response::Response,
util::AxumMutex,
};
use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
Expand All @@ -29,7 +28,7 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
pub struct Route<E = Infallible>(BoxCloneService<Request, Response, E>);

impl<E> Route<E> {
pub(crate) fn new<T>(svc: T) -> Self
Expand All @@ -38,16 +37,16 @@ impl<E> Route<E> {
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
Self(AxumMutex::new(BoxCloneService::new(
Self(BoxCloneService::new(
svc.map_response(IntoResponse::into_response),
)))
))
}

pub(crate) fn oneshot_inner(
&mut self,
req: Request,
) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
self.0.get_mut().unwrap().clone().oneshot(req)
self.0.clone().oneshot(req)
}

pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
Expand All @@ -73,7 +72,7 @@ impl<E> Route<E> {
impl<E> Clone for Route<E> {
#[track_caller]
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
Self(self.0.clone())
}
}

Expand Down
33 changes: 0 additions & 33 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use crate::{
tracing_helpers::{capture_tracing, TracingEvent},
*,
},
util::mutex_num_locked,
BoxError, Extension, Json, Router, ServiceExt,
};
use axum_core::extract::Request;
Expand Down Expand Up @@ -1068,35 +1067,3 @@ async fn impl_handler_for_into_response() {
assert_eq!(res.status(), StatusCode::CREATED);
assert_eq!(res.text().await, "thing created");
}

#[crate::test]
async fn locks_mutex_very_little() {
let (num, app) = mutex_num_locked(|| async {
Router::new()
.route("/a", get(|| async {}))
.route("/b", get(|| async {}))
.route("/c", get(|| async {}))
.with_state::<()>(())
.into_service::<Body>()
})
.await;
// once for `Router::new` for setting the default fallback and 3 times, once per route
assert_eq!(num, 4);

for path in ["/a", "/b", "/c"] {
// calling the router should only lock the mutex once
let (num, _res) = mutex_num_locked(|| async {
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
// connection and `mutex_num_locked` uses a task local to keep track of the number of
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
//
// So instead `call` the service directly without spawning new tasks.
app.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await
.unwrap()
})
.await;
assert_eq!(num, 1);
}
}
9 changes: 4 additions & 5 deletions axum/src/test_helpers/tracing_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::util::AxumMutex;
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
sync::{Arc, Mutex},
};

use serde::{de::DeserializeOwned, Deserialize};
Expand Down Expand Up @@ -87,12 +86,12 @@ where
}

struct TestMakeWriter {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}

impl TestMakeWriter {
fn new() -> (Self, Handle) {
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));

(
Self {
Expand Down Expand Up @@ -134,7 +133,7 @@ impl<'a> io::Write for Writer<'a> {
}

struct Handle {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}

impl Handle {
Expand Down
63 changes: 0 additions & 63 deletions axum/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use pin_project_lite::pin_project;
use std::{ops::Deref, sync::Arc};

pub(crate) use self::mutex::*;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedStr(Arc<str>);

Expand Down Expand Up @@ -57,64 +55,3 @@ fn test_try_downcast() {
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
}

// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
// times it's been locked on the current task. That way we can write a test to ensure we don't
// accidentally introduce more locking.
//
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
mod mutex {
#[allow(clippy::disallowed_types)]
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
}

#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod mutex {
use std::sync::{
atomic::{AtomicUsize, Ordering},
LockResult, Mutex, MutexGuard,
};

tokio::task_local! {
pub(crate) static NUM_LOCKED: AtomicUsize;
}

pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
where
F: FnOnce() -> Fut,
Fut: std::future::IntoFuture,
{
NUM_LOCKED
.scope(AtomicUsize::new(0), async move {
let output = f().await;
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
(num, output)
})
.await
}

pub(crate) struct AxumMutex<T>(Mutex<T>);

impl<T> AxumMutex<T> {
pub(crate) fn new(value: T) -> Self {
Self(Mutex::new(value))
}

pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
self.0.get_mut()
}

pub(crate) fn into_inner(self) -> LockResult<T> {
self.0.into_inner()
}

pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
_ = NUM_LOCKED.try_with(|num| {
num.fetch_add(1, Ordering::SeqCst);
});
self.0.lock()
}
}
}

0 comments on commit fb4b189

Please sign in to comment.