Skip to content

Commit

Permalink
Remove duplication in serving with and without graceful shutdown (#2803)
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte authored Sep 27, 2024
1 parent 4b48f30 commit 5d8541d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 84 deletions.
3 changes: 2 additions & 1 deletion axum/src/routing/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ async fn logging_rejections() {
rejection_type: String,
}

let events = capture_tracing::<RejectionEvent, _, _>(|| async {
let events = capture_tracing::<RejectionEvent, _>(|| async {
let app = Router::new()
.route("/extension", get(|_: Extension<Infallible>| async {}))
.route("/string", post(|_: String| async {}));
Expand All @@ -987,6 +987,7 @@ async fn logging_rejections() {
StatusCode::BAD_REQUEST,
);
})
.with_filter("axum::rejection=trace")
.await;

assert_eq!(
Expand Down
57 changes: 2 additions & 55 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,61 +213,8 @@ where
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
private::ServeFuture(Box::pin(async move {
let Self {
tcp_listener,
mut make_service,
tcp_nodelay,
_marker: _,
} = self;

loop {
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
Some(conn) => conn,
None => continue,
};

if let Some(nodelay) = tcp_nodelay {
if let Err(err) = tcp_stream.set_nodelay(nodelay) {
trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
}

let tcp_stream = TokioIo::new(tcp_stream);

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));

let hyper_service = TowerToHyperService::new(tower_service);

tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
.await
{
Ok(()) => {}
Err(_err) => {
// This error only appears when the client doesn't send a request and
// terminate the connection.
//
// If client sends one request then terminate connection whenever, it doesn't
// appear.
}
}
});
}
}))
self.with_graceful_shutdown(std::future::pending())
.into_future()
}
}

Expand Down
96 changes: 68 additions & 28 deletions axum/src/test_helpers/tracing_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
use crate::util::AxumMutex;
use std::{future::Future, io, sync::Arc};
use std::{
future::{Future, IntoFuture},
io,
marker::PhantomData,
pin::Pin,
sync::Arc,
};

use serde::{de::DeserializeOwned, Deserialize};
use tracing::instrument::WithSubscriber;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{filter::Targets, fmt::MakeWriter};

Expand All @@ -14,36 +21,69 @@ pub(crate) struct TracingEvent<T> {
}

/// Run an async closure and capture the tracing output it produces.
pub(crate) async fn capture_tracing<T, F, Fut>(f: F) -> Vec<TracingEvent<T>>
pub(crate) fn capture_tracing<T, F>(f: F) -> CaptureTracing<T, F>
where
F: Fn() -> Fut,
Fut: Future,
T: DeserializeOwned,
{
let (make_writer, handle) = TestMakeWriter::new();

let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter("axum=trace".parse::<Targets>().unwrap()),
);

let guard = tracing::subscriber::set_default(subscriber);

f().await;

drop(guard);

handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
CaptureTracing {
f,
filter: None,
_phantom: PhantomData,
}
}

pub(crate) struct CaptureTracing<T, F> {
f: F,
filter: Option<Targets>,
_phantom: PhantomData<fn() -> T>,
}

impl<T, F> CaptureTracing<T, F> {
pub(crate) fn with_filter(mut self, filter_string: &str) -> Self {
self.filter = Some(filter_string.parse().unwrap());
self
}
}

impl<T, F, Fut> IntoFuture for CaptureTracing<T, F>
where
F: Fn() -> Fut + Send + Sync + 'static,
Fut: Future + Send,
T: DeserializeOwned,
{
type Output = Vec<TracingEvent<T>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

fn into_future(self) -> Self::IntoFuture {
let Self { f, filter, .. } = self;
Box::pin(async move {
let (make_writer, handle) = TestMakeWriter::new();

let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap());
let subscriber = tracing_subscriber::registry().with(
tracing_subscriber::fmt::layer()
.with_writer(make_writer)
.with_target(true)
.without_time()
.with_ansi(false)
.json()
.flatten_event(false)
.with_filter(filter),
);

let guard = tracing::subscriber::set_default(subscriber);

f().with_current_subscriber().await;

drop(guard);

handle
.take()
.lines()
.map(|line| serde_json::from_str(line).unwrap())
.collect()
})
}
}

struct TestMakeWriter {
Expand Down

0 comments on commit 5d8541d

Please sign in to comment.