From 72871dc30729f79df91fcb4803d9e5160dca7180 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Fri, 27 Sep 2024 23:57:26 +0200 Subject: [PATCH] Make `serve` generic over the listener and IO types Co-authored-by: David Pedersen --- axum/CHANGELOG.md | 2 + axum/Cargo.toml | 1 + .../into_make_service_with_connect_info.md | 5 +- axum/src/extract/connect_info.rs | 15 +- axum/src/handler/service.rs | 7 +- axum/src/routing/method_routing.rs | 9 +- axum/src/routing/mod.rs | 9 +- axum/src/serve.rs | 335 +++++++++++++----- examples/unix-domain-socket/src/main.rs | 56 +-- 9 files changed, 297 insertions(+), 142 deletions(-) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 0298701849..d42e00565b 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -8,8 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased - **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931]) +- **breaking:** Make `serve` generic over the listener and IO types ([#2941]) [#2931]: https://github.com/tokio-rs/axum/pull/2931 +[#2941]: https://github.com/tokio-rs/axum/pull/2941 # 0.7.7 diff --git a/axum/Cargo.toml b/axum/Cargo.toml index ad593adcd1..c6133c720f 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -113,6 +113,7 @@ features = [ [dev-dependencies] anyhow = "1.0" axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] } +hyper = { version = "1.1.0", features = ["client"] } quickcheck = "1.0" quickcheck_macros = "1.0" reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } diff --git a/axum/src/docs/routing/into_make_service_with_connect_info.md b/axum/src/docs/routing/into_make_service_with_connect_info.md index 26d0602f31..088f21f9d4 100644 --- a/axum/src/docs/routing/into_make_service_with_connect_info.md +++ b/axum/src/docs/routing/into_make_service_with_connect_info.md @@ -35,6 +35,7 @@ use axum::{ serve::IncomingStream, Router, }; +use tokio::net::TcpListener; let app = Router::new().route("/", get(handler)); @@ -49,8 +50,8 @@ struct MyConnectInfo { // ... } -impl Connected> for MyConnectInfo { - fn connect_info(target: IncomingStream<'_>) -> Self { +impl Connected> for MyConnectInfo { + fn connect_info(target: IncomingStream<'_, TcpListener>) -> Self { MyConnectInfo { // ... } diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index f77db6dd44..6b13aa41b0 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -80,16 +80,17 @@ where /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info pub trait Connected: Clone + Send + Sync + 'static { /// Create type holding information about the connection. - fn connect_info(target: T) -> Self; + fn connect_info(stream: T) -> Self; } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; + use tokio::net::TcpListener; - impl Connected> for SocketAddr { - fn connect_info(target: IncomingStream<'_>) -> Self { - target.remote_addr() + impl Connected> for SocketAddr { + fn connect_info(stream: serve::IncomingStream<'_, TcpListener>) -> Self { + *stream.remote_addr() } } }; @@ -263,8 +264,8 @@ mod tests { value: &'static str, } - impl Connected> for MyConnectInfo { - fn connect_info(_target: IncomingStream<'_>) -> Self { + impl Connected> for MyConnectInfo { + fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self { Self { value: "it worked!", } diff --git a/axum/src/handler/service.rs b/axum/src/handler/service.rs index e6b8df9316..2090051978 100644 --- a/axum/src/handler/service.rs +++ b/axum/src/handler/service.rs @@ -180,12 +180,13 @@ where // for `axum::serve(listener, handler)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for HandlerService + impl Service> for HandlerService where H: Clone, S: Clone, + L: serve::Listener, { type Response = Self; type Error = Infallible; @@ -195,7 +196,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone())) } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 1eb6075b22..a0712aab38 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -1227,9 +1227,12 @@ where // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for MethodRouter<()> { + impl Service> for MethodRouter<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -1238,7 +1241,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { std::future::ready(Ok(self.clone().with_state(()))) } } diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index dc6ca81591..15bac0e8a7 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -486,9 +486,12 @@ impl Router { // for `axum::serve(listener, router)` #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] const _: () = { - use crate::serve::IncomingStream; + use crate::serve; - impl Service> for Router<()> { + impl Service> for Router<()> + where + L: serve::Listener, + { type Response = Self; type Error = Infallible; type Future = std::future::Ready>; @@ -497,7 +500,7 @@ const _: () = { Poll::Ready(Ok(())) } - fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { + fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future { // call `Router::with_state` such that everything is turned into `Route` eagerly // rather than doing that per request std::future::ready(Ok(self.clone().with_state(()))) diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 1ba9a1452c..c47baf7d1c 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -1,12 +1,12 @@ //! Serve services. use std::{ + any::TypeId, convert::Infallible, fmt::Debug, future::{poll_fn, Future, IntoFuture}, io, marker::PhantomData, - net::SocketAddr, sync::Arc, time::Duration, }; @@ -18,12 +18,59 @@ use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::{ + io::{AsyncRead, AsyncWrite}, net::{TcpListener, TcpStream}, sync::watch, }; use tower::ServiceExt as _; use tower_service::Service; +/// Types that can listen for connections. +pub trait Listener: Send + 'static { + /// The listener's IO type. + type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// The listener's address type. + type Addr: Send; + + /// Accept a new incoming connection to this listener + fn accept(&mut self) -> impl Future> + Send; + + /// Returns the local address that this listener is bound to. + fn local_addr(&self) -> io::Result; +} + +impl Listener for TcpListener { + type Io = TcpStream; + type Addr = std::net::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + +#[cfg(unix)] +impl Listener for tokio::net::UnixListener { + type Io = tokio::net::UnixStream; + type Addr = tokio::net::unix::SocketAddr; + + #[inline] + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + Self::accept(self).await + } + + #[inline] + fn local_addr(&self) -> io::Result { + Self::local_addr(self) + } +} + /// Serve the service with the supplied listener. /// /// This method of running a service is intentionally simple and doesn't support any configuration. @@ -89,14 +136,15 @@ use tower_service::Service; /// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info /// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -pub fn serve(tcp_listener: TcpListener, make_service: M) -> Serve +pub fn serve(listener: L, make_service: M) -> Serve where - M: for<'a> Service, Error = Infallible, Response = S>, + L: Listener, + M: for<'a> Service, Error = Infallible, Response = S>, S: Service + Clone + Send + 'static, S::Future: Send, { Serve { - tcp_listener, + listener, make_service, tcp_nodelay: None, _marker: PhantomData, @@ -106,15 +154,18 @@ where /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct Serve { - tcp_listener: TcpListener, +pub struct Serve { + listener: L, make_service: M, tcp_nodelay: Option, _marker: PhantomData, } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Serve { +impl Serve +where + L: Listener, +{ /// Prepares a server to handle graceful shutdown when the provided future completes. /// /// # Example @@ -136,12 +187,12 @@ impl Serve { /// // ... /// } /// ``` - pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown + pub fn with_graceful_shutdown(self, signal: F) -> WithGracefulShutdown where F: Future + Send + 'static, { WithGracefulShutdown { - tcp_listener: self.tcp_listener, + listener: self.listener, make_service: self.make_service, signal, tcp_nodelay: self.tcp_nodelay, @@ -149,6 +200,14 @@ impl Serve { } } + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl Serve { /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// /// See also [`TcpStream::set_nodelay`]. @@ -173,39 +232,41 @@ impl Serve { ..self } } - - /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() - } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for Serve +impl Debug for Serve where + L: Debug + 'static, M: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, tcp_nodelay, _marker: _, } = self; - f.debug_struct("Serve") - .field("tcp_listener", tcp_listener) - .field("make_service", make_service) - .field("tcp_nodelay", tcp_nodelay) - .finish() + let mut s = f.debug_struct("Serve"); + s.field("listener", listener) + .field("make_service", make_service); + + if TypeId::of::() == TypeId::of::() { + s.field("tcp_nodelay", tcp_nodelay); + } + + s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for Serve +impl IntoFuture for Serve where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, { @@ -221,15 +282,27 @@ where /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[must_use = "futures must be awaited or polled"] -pub struct WithGracefulShutdown { - tcp_listener: TcpListener, +pub struct WithGracefulShutdown { + listener: L, make_service: M, signal: F, tcp_nodelay: Option, _marker: PhantomData, } -impl WithGracefulShutdown { +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown +where + L: Listener, +{ + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } +} + +#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +impl WithGracefulShutdown { /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. /// /// See also [`TcpStream::set_nodelay`]. @@ -259,43 +332,45 @@ impl WithGracefulShutdown { ..self } } - - /// Returns the local address this server is bound to. - pub fn local_addr(&self) -> io::Result { - self.tcp_listener.local_addr() - } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl Debug for WithGracefulShutdown +impl Debug for WithGracefulShutdown where + L: Debug + 'static, M: Debug, S: Debug, F: Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let Self { - tcp_listener, + listener, make_service, signal, tcp_nodelay, _marker: _, } = self; - f.debug_struct("WithGracefulShutdown") - .field("tcp_listener", tcp_listener) + let mut s = f.debug_struct("WithGracefulShutdown"); + s.field("listener", listener) .field("make_service", make_service) - .field("signal", signal) - .field("tcp_nodelay", tcp_nodelay) - .finish() + .field("signal", signal); + + if TypeId::of::() == TypeId::of::() { + s.field("tcp_nodelay", tcp_nodelay); + } + + s.finish() } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] -impl IntoFuture for WithGracefulShutdown +impl IntoFuture for WithGracefulShutdown where - M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, - for<'a> >>::Future: Send, + L: Listener, + L::Addr: Debug, + M: for<'a> Service, Error = Infallible, Response = S> + Send + 'static, + for<'a> >>::Future: Send, S: Service + Clone + Send + 'static, S::Future: Send, F: Future + Send + 'static, @@ -305,7 +380,7 @@ where fn into_future(self) -> Self::IntoFuture { let Self { - tcp_listener, + mut listener, mut make_service, signal, tcp_nodelay, @@ -324,8 +399,8 @@ where private::ServeFuture(Box::pin(async move { loop { - let (tcp_stream, remote_addr) = tokio::select! { - conn = tcp_accept(&tcp_listener) => { + let (io, remote_addr) = tokio::select! { + conn = accept(&mut listener) => { match conn { Some(conn) => conn, None => continue, @@ -338,14 +413,16 @@ where }; if let Some(nodelay) = tcp_nodelay { + let tcp_stream: &tokio::net::TcpStream = ::downcast_ref(&io) + .expect("internal error: tcp_nodelay used with the wrong type of listener"); if let Err(err) = tcp_stream.set_nodelay(nodelay) { trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); } } - let tcp_stream = TokioIo::new(tcp_stream); + let io = TokioIo::new(io); - trace!("connection {remote_addr} accepted"); + trace!("connection {remote_addr:?} accepted"); poll_fn(|cx| make_service.poll_ready(cx)) .await @@ -353,7 +430,7 @@ where let tower_service = make_service .call(IncomingStream { - tcp_stream: &tcp_stream, + io: &io, remote_addr, }) .await @@ -368,7 +445,7 @@ where tokio::spawn(async move { let builder = Builder::new(TokioExecutor::new()); - let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); + let conn = builder.serve_connection_with_upgrades(io, hyper_service); pin_mut!(conn); let signal_closed = signal_tx.closed().fuse(); @@ -389,14 +466,12 @@ where } } - trace!("connection {remote_addr} closed"); - drop(close_rx); }); } drop(close_rx); - drop(tcp_listener); + drop(listener); trace!( "waiting for {} task(s) to finish", @@ -418,7 +493,10 @@ fn is_connection_error(e: &io::Error) -> bool { ) } -async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { +async fn accept(listener: &mut L) -> Option<(L::Io, L::Addr)> +where + L: Listener, +{ match listener.accept().await { Ok(conn) => Some(conn), Err(e) => { @@ -444,6 +522,35 @@ async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> { } } +/// An incoming stream. +/// +/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. +/// +/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo +#[derive(Debug)] +pub struct IncomingStream<'a, L> +where + L: Listener, +{ + io: &'a TokioIo, + remote_addr: L::Addr, +} + +impl IncomingStream<'_, L> +where + L: Listener, +{ + /// Get a reference to the inner IO type. + pub fn io(&self) -> &L::Io { + self.io.inner() + } + + /// Returns the remote address that this stream is bound to. + pub fn remote_addr(&self) -> &L::Addr { + &self.remote_addr + } +} + mod private { use std::{ future::Future, @@ -470,33 +577,15 @@ mod private { } } -/// An incoming stream. -/// -/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. -/// -/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo -#[derive(Debug)] -pub struct IncomingStream<'a> { - tcp_stream: &'a TokioIo, - remote_addr: SocketAddr, -} - -impl IncomingStream<'_> { - /// Returns the local address that this stream is bound to. - pub fn local_addr(&self) -> std::io::Result { - self.tcp_stream.inner().local_addr() - } - - /// Returns the remote address that this stream is bound to. - pub fn remote_addr(&self) -> SocketAddr { - self.remote_addr - } -} - #[cfg(test)] mod tests { + use http::StatusCode; + use tokio::net::UnixListener; + use super::*; use crate::{ + body::to_bytes, + extract::connect_info::Connected, handler::{Handler, HandlerWithoutStateExt}, routing::get, Router, @@ -508,30 +597,63 @@ mod tests { #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { + #[derive(Clone, Debug)] + struct UdsConnectInfo; + + impl Connected> for UdsConnectInfo { + fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self { + Self + } + } + let router: Router = Router::new(); let addr = "0.0.0.0:0"; // router serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + serve(UnixListener::bind("").unwrap(), router.clone()); + serve( TcpListener::bind(addr).await.unwrap(), router.clone().into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + router.clone().into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - router.into_make_service_with_connect_info::(), + router + .clone() + .into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + router.into_make_service_with_connect_info::(), ); // method router serve(TcpListener::bind(addr).await.unwrap(), get(handler)); + serve(UnixListener::bind("").unwrap(), get(handler)); + serve( TcpListener::bind(addr).await.unwrap(), get(handler).into_make_service(), ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service(), + ); + serve( TcpListener::bind(addr).await.unwrap(), - get(handler).into_make_service_with_connect_info::(), + get(handler).into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + get(handler).into_make_service_with_connect_info::(), ); // handler @@ -539,17 +661,27 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_service()); + serve( TcpListener::bind(addr).await.unwrap(), handler.with_state(()), ); + serve(UnixListener::bind("").unwrap(), handler.with_state(())); + serve( TcpListener::bind(addr).await.unwrap(), handler.into_make_service(), ); + serve(UnixListener::bind("").unwrap(), handler.into_make_service()); + serve( TcpListener::bind(addr).await.unwrap(), - handler.into_make_service_with_connect_info::(), + handler.into_make_service_with_connect_info::(), + ); + serve( + UnixListener::bind("").unwrap(), + handler.into_make_service_with_connect_info::(), ); // nodelay @@ -593,4 +725,49 @@ mod tests { assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); assert_ne!(address.port(), 0); } + + #[crate::test] + async fn serving_on_custom_io_type() { + struct ReadyListener(Option); + + impl Listener for ReadyListener + where + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Io = T; + type Addr = (); + + async fn accept(&mut self) -> io::Result<(Self::Io, Self::Addr)> { + match self.0.take() { + Some(server) => Ok((server, ())), + None => std::future::pending().await, + } + } + + fn local_addr(&self) -> io::Result { + Ok(()) + } + } + + let (client, server) = tokio::io::duplex(1024); + let listener = ReadyListener(Some(server)); + + let app = Router::new().route("/", get(|| async { "Hello, World!" })); + + tokio::spawn(serve(listener, app).into_future()); + + let stream = TokioIo::new(client); + let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap(); + tokio::spawn(conn); + + let request = Request::builder().body(Body::empty()).unwrap(); + + let response = sender.send_request(request).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let body = Body::new(response.into_body()); + let body = to_bytes(body, usize::MAX).await.unwrap(); + let body = String::from_utf8(body.to_vec()).unwrap(); + assert_eq!(body, "Hello, World!"); + } } diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index fbb4c3b067..b60bc8211f 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -23,17 +23,13 @@ mod unix { extract::connect_info::{self, ConnectInfo}, http::{Method, Request, StatusCode}, routing::get, + serve::IncomingStream, Router, }; use http_body_util::BodyExt; - use hyper::body::Incoming; - use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, - }; - use std::{convert::Infallible, path::PathBuf, sync::Arc}; + use hyper_util::rt::TokioIo; + use std::{path::PathBuf, sync::Arc}; use tokio::net::{unix::UCred, UnixListener, UnixStream}; - use tower::Service; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub async fn server() { @@ -54,33 +50,11 @@ mod unix { let uds = UnixListener::bind(path.clone()).unwrap(); tokio::spawn(async move { - let app = Router::new().route("/", get(handler)); - - let mut make_service = app.into_make_service_with_connect_info::(); - - // See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for - // more details about this setup - loop { - let (socket, _remote_addr) = uds.accept().await.unwrap(); - - let tower_service = unwrap_infallible(make_service.call(&socket).await); - - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - let hyper_service = - hyper::service::service_fn(move |request: Request| { - tower_service.clone().call(request) - }); + let app = Router::new() + .route("/", get(handler)) + .into_make_service_with_connect_info::(); - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } + axum::serve(uds, app).await.unwrap(); }); let stream = TokioIo::new(UnixStream::connect(path).await.unwrap()); @@ -119,22 +93,14 @@ mod unix { peer_cred: UCred, } - impl connect_info::Connected<&UnixStream> for UdsConnectInfo { - fn connect_info(target: &UnixStream) -> Self { - let peer_addr = target.peer_addr().unwrap(); - let peer_cred = target.peer_cred().unwrap(); - + impl connect_info::Connected> for UdsConnectInfo { + fn connect_info(stream: IncomingStream<'_, UnixListener>) -> Self { + let peer_addr = stream.io().peer_addr().unwrap(); + let peer_cred = stream.io().peer_cred().unwrap(); Self { peer_addr: Arc::new(peer_addr), peer_cred, } } } - - fn unwrap_infallible(result: Result) -> T { - match result { - Ok(value) => value, - Err(err) => match err {}, - } - } }