Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove mutexes around boxed services #2947

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions axum/clippy.toml

This file was deleted.

17 changes: 8 additions & 9 deletions axum/src/boxed.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{convert::Infallible, fmt};

use crate::extract::Request;
use crate::util::AxumMutex;
use tower::Service;

use crate::{
Expand All @@ -10,7 +9,7 @@ use crate::{
Router,
};

pub(crate) struct BoxedIntoRoute<S, E>(AxumMutex<Box<dyn ErasedIntoRoute<S, E>>>);
pub(crate) struct BoxedIntoRoute<S, E>(Box<dyn ErasedIntoRoute<S, E>>);

impl<S> BoxedIntoRoute<S, Infallible>
where
Expand All @@ -21,10 +20,10 @@ where
H: Handler<T, S>,
T: 'static,
{
Self(AxumMutex::new(Box::new(MakeErasedHandler {
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state(handler, state)),
})))
}))
}
}

Expand All @@ -36,20 +35,20 @@ impl<S, E> BoxedIntoRoute<S, E> {
F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
E2: 'static,
{
BoxedIntoRoute(AxumMutex::new(Box::new(Map {
inner: self.0.into_inner().unwrap(),
BoxedIntoRoute(Box::new(Map {
inner: self.0,
layer: Box::new(f),
})))
}))
}

pub(crate) fn into_route(self, state: S) -> Route<E> {
self.0.into_inner().unwrap().into_route(state)
self.0.into_route(state)
}
}

impl<S, E> Clone for BoxedIntoRoute<S, E> {
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone_box()))
Self(self.0.clone_box())
}
}

Expand Down
11 changes: 5 additions & 6 deletions axum/src/routing/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
body::{Body, HttpBody},
box_clone_service::BoxCloneService,
response::Response,
util::AxumMutex,
};
use axum_core::{extract::Request, response::IntoResponse};
use bytes::Bytes;
Expand All @@ -29,7 +28,7 @@ use tower_service::Service;
///
/// You normally shouldn't need to care about this type. It's used in
/// [`Router::layer`](super::Router::layer).
pub struct Route<E = Infallible>(AxumMutex<BoxCloneService<Request, Response, E>>);
pub struct Route<E = Infallible>(BoxCloneService<Request, Response, E>);

impl<E> Route<E> {
pub(crate) fn new<T>(svc: T) -> Self
Expand All @@ -38,16 +37,16 @@ impl<E> Route<E> {
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
{
Self(AxumMutex::new(BoxCloneService::new(
Self(BoxCloneService::new(
svc.map_response(IntoResponse::into_response),
)))
))
}

pub(crate) fn oneshot_inner(
&mut self,
req: Request,
) -> Oneshot<BoxCloneService<Request, Response, E>, Request> {
self.0.get_mut().unwrap().clone().oneshot(req)
self.0.clone().oneshot(req)
}

pub(crate) fn layer<L, NewError>(self, layer: L) -> Route<NewError>
Expand All @@ -73,7 +72,7 @@ impl<E> Route<E> {
impl<E> Clone for Route<E> {
#[track_caller]
fn clone(&self) -> Self {
Self(AxumMutex::new(self.0.lock().unwrap().clone()))
Self(self.0.clone())
}
}

Expand Down
33 changes: 0 additions & 33 deletions axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use crate::{
tracing_helpers::{capture_tracing, TracingEvent},
*,
},
util::mutex_num_locked,
BoxError, Extension, Json, Router, ServiceExt,
};
use axum_core::extract::Request;
Expand Down Expand Up @@ -1068,35 +1067,3 @@ async fn impl_handler_for_into_response() {
assert_eq!(res.status(), StatusCode::CREATED);
assert_eq!(res.text().await, "thing created");
}

#[crate::test]
async fn locks_mutex_very_little() {
let (num, app) = mutex_num_locked(|| async {
Router::new()
.route("/a", get(|| async {}))
.route("/b", get(|| async {}))
.route("/c", get(|| async {}))
.with_state::<()>(())
.into_service::<Body>()
})
.await;
// once for `Router::new` for setting the default fallback and 3 times, once per route
assert_eq!(num, 4);

for path in ["/a", "/b", "/c"] {
// calling the router should only lock the mutex once
let (num, _res) = mutex_num_locked(|| async {
// We cannot use `TestClient` because it uses `serve` which spawns a new task per
// connection and `mutex_num_locked` uses a task local to keep track of the number of
// locks. So spawning a new task would unset the task local set by `mutex_num_locked`
//
// So instead `call` the service directly without spawning new tasks.
app.clone()
.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
.await
.unwrap()
})
.await;
assert_eq!(num, 1);
}
}
9 changes: 4 additions & 5 deletions axum/src/test_helpers/tracing_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use crate::util::AxumMutex;
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
sync::{Arc, Mutex},
};

use serde::{de::DeserializeOwned, Deserialize};
Expand Down Expand Up @@ -87,12 +86,12 @@ where
}

struct TestMakeWriter {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}

impl TestMakeWriter {
fn new() -> (Self, Handle) {
let write = Arc::new(AxumMutex::new(Some(Vec::<u8>::new())));
let write = Arc::new(Mutex::new(Some(Vec::<u8>::new())));

(
Self {
Expand Down Expand Up @@ -134,7 +133,7 @@ impl<'a> io::Write for Writer<'a> {
}

struct Handle {
write: Arc<AxumMutex<Option<Vec<u8>>>>,
write: Arc<Mutex<Option<Vec<u8>>>>,
}

impl Handle {
Expand Down
63 changes: 0 additions & 63 deletions axum/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use pin_project_lite::pin_project;
use std::{ops::Deref, sync::Arc};

pub(crate) use self::mutex::*;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct PercentDecodedStr(Arc<str>);

Expand Down Expand Up @@ -57,64 +55,3 @@ fn test_try_downcast() {
assert_eq!(try_downcast::<i32, _>(5_u32), Err(5_u32));
assert_eq!(try_downcast::<i32, _>(5_i32), Ok(5_i32));
}

// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of
// times it's been locked on the current task. That way we can write a test to ensure we don't
// accidentally introduce more locking.
//
// When not in test mode, it is just a type alias for `std::sync::Mutex`.
#[cfg(not(test))]
mod mutex {
#[allow(clippy::disallowed_types)]
pub(crate) type AxumMutex<T> = std::sync::Mutex<T>;
}

#[cfg(test)]
#[allow(clippy::disallowed_types)]
mod mutex {
use std::sync::{
atomic::{AtomicUsize, Ordering},
LockResult, Mutex, MutexGuard,
};

tokio::task_local! {
pub(crate) static NUM_LOCKED: AtomicUsize;
}

pub(crate) async fn mutex_num_locked<F, Fut>(f: F) -> (usize, Fut::Output)
where
F: FnOnce() -> Fut,
Fut: std::future::IntoFuture,
{
NUM_LOCKED
.scope(AtomicUsize::new(0), async move {
let output = f().await;
let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst));
(num, output)
})
.await
}

pub(crate) struct AxumMutex<T>(Mutex<T>);

impl<T> AxumMutex<T> {
pub(crate) fn new(value: T) -> Self {
Self(Mutex::new(value))
}

pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> {
self.0.get_mut()
}

pub(crate) fn into_inner(self) -> LockResult<T> {
self.0.into_inner()
}

pub(crate) fn lock(&self) -> LockResult<MutexGuard<'_, T>> {
_ = NUM_LOCKED.try_with(|num| {
num.fetch_add(1, Ordering::SeqCst);
});
self.0.lock()
}
}
}