From dee2ab52ff4a2995156a3baf5ea916b479fd1d14 Mon Sep 17 00:00:00 2001 From: Anthony Green Date: Mon, 14 Feb 2022 15:00:58 -0500 Subject: [PATCH] feat(transport): add unix socket support in server (#861) --- examples/src/uds/server.rs | 85 ++----------------- tests/integration_tests/tests/connect_info.rs | 74 ++++++++++++++++ tonic/src/request.rs | 4 +- tonic/src/transport/server/mod.rs | 5 ++ tonic/src/transport/server/unix.rs | 31 +++++++ 5 files changed, 120 insertions(+), 79 deletions(-) create mode 100644 tonic/src/transport/server/unix.rs diff --git a/examples/src/uds/server.rs b/examples/src/uds/server.rs index 8ffe4fa0a..b4ca31002 100644 --- a/examples/src/uds/server.rs +++ b/examples/src/uds/server.rs @@ -1,9 +1,12 @@ #![cfg_attr(not(unix), allow(unused_imports))] -use futures::TryFutureExt; use std::path::Path; #[cfg(unix)] use tokio::net::UnixListener; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; +#[cfg(unix)] +use tonic::transport::server::UdsConnectInfo; use tonic::{transport::Server, Request, Response, Status}; pub mod hello_world { @@ -26,7 +29,7 @@ impl Greeter for MyGreeter { ) -> Result, Status> { #[cfg(unix)] { - let conn_info = request.extensions().get::().unwrap(); + let conn_info = request.extensions().get::().unwrap(); println!("Got a request {:?} with info {:?}", request, conn_info); } @@ -46,89 +49,17 @@ async fn main() -> Result<(), Box> { let greeter = MyGreeter::default(); - let incoming = { - let uds = UnixListener::bind(path)?; - - async_stream::stream! { - loop { - let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await; - - yield item; - } - } - }; + let uds = UnixListener::bind(path)?; + let uds_stream = UnixListenerStream::new(uds); Server::builder() .add_service(GreeterServer::new(greeter)) - .serve_with_incoming(incoming) + .serve_with_incoming(uds_stream) .await?; Ok(()) } -#[cfg(unix)] -mod unix { - use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; - - use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - use tonic::transport::server::Connected; - - #[derive(Debug)] - pub struct UnixStream(pub tokio::net::UnixStream); - - impl Connected for UnixStream { - type ConnectInfo = UdsConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - UdsConnectInfo { - peer_addr: self.0.peer_addr().ok().map(Arc::new), - peer_cred: self.0.peer_cred().ok(), - } - } - } - - #[derive(Clone, Debug)] - pub struct UdsConnectInfo { - pub peer_addr: Option>, - pub peer_cred: Option, - } - - impl AsyncRead for UnixStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } - } - - impl AsyncWrite for UnixStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } - } -} - #[cfg(not(unix))] fn main() { panic!("The `uds` example only works on unix systems!"); diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 936eedac1..cc57f6b6b 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -48,3 +48,77 @@ async fn getting_connect_info() { jh.await.unwrap(); } + +#[cfg(unix)] +pub mod unix { + use std::convert::TryFrom as _; + + use futures_util::FutureExt; + use tokio::{ + net::{UnixListener, UnixStream}, + sync::oneshot, + }; + use tokio_stream::wrappers::UnixListenerStream; + use tonic::{ + transport::{server::UdsConnectInfo, Endpoint, Server, Uri}, + Request, Response, Status, + }; + use tower::service_fn; + + use integration_tests::pb::{test_client, test_server, Input, Output}; + + struct Svc {} + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + let conn_info = req.extensions().get::().unwrap(); + + // Client-side unix sockets are unnamed. + assert!(req.remote_addr().is_none()); + assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed()); + // This should contain process credentials for the client socket. + assert!(conn_info.peer_cred.as_ref().is_some()); + + Ok(Response::new(Output {})) + } + } + + #[tokio::test] + async fn getting_connect_info() { + let mut unix_socket_path = std::env::temp_dir(); + unix_socket_path.push("uds-integration-test"); + + let uds = UnixListener::bind(&unix_socket_path).unwrap(); + let uds_stream = UnixListenerStream::new(uds); + + let service = test_server::TestServer::new(Svc {}); + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(service) + .serve_with_incoming_shutdown(uds_stream, rx.map(drop)) + .await + .unwrap(); + }); + + // Take a copy before moving into the `service_fn` closure so that the closure + // can implement `FnMut`. + let path = unix_socket_path.clone(); + let channel = Endpoint::try_from("http://[::]:50051") + .unwrap() + .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone()))) + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + jh.await.unwrap(); + + std::fs::remove_file(unix_socket_path).unwrap(); + } +} diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 46a2d486d..64dd042cf 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -202,8 +202,8 @@ impl Request { /// Get the remote address of this connection. /// /// This will return `None` if the `IO` type used - /// does not implement `Connected`. This currently, - /// only works on the server side. + /// does not implement `Connected` or when using a unix domain socket. + /// This currently only works on the server side. pub fn remote_addr(&self) -> Option { #[cfg(feature = "transport")] { diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index adba6f896..6a391cd62 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -6,6 +6,8 @@ mod recover_error; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; +#[cfg(unix)] +mod unix; pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] @@ -17,6 +19,9 @@ pub use conn::TlsConnectInfo; #[cfg(feature = "tls")] use super::service::TlsAcceptor; +#[cfg(unix)] +pub use unix::UdsConnectInfo; + use incoming::TcpIncoming; #[cfg(feature = "tls")] diff --git a/tonic/src/transport/server/unix.rs b/tonic/src/transport/server/unix.rs new file mode 100644 index 000000000..31454b7d1 --- /dev/null +++ b/tonic/src/transport/server/unix.rs @@ -0,0 +1,31 @@ +use super::Connected; +use std::sync::Arc; + +/// Connection info for Unix domain socket streams. +/// +/// This type will be accessible through [request extensions][ext] if you're using +/// a unix stream. +/// +/// See [Connected] for more details. +/// +/// [ext]: crate::Request::extensions +/// [Connected]: crate::transport::server::Connected +#[cfg_attr(docsrs, doc(cfg(unix)))] +#[derive(Clone, Debug)] +pub struct UdsConnectInfo { + /// Peer address. This will be "unnamed" for client unix sockets. + pub peer_addr: Option>, + /// Process credentials for the unix socket. + pub peer_cred: Option, +} + +impl Connected for tokio::net::UnixStream { + type ConnectInfo = UdsConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + UdsConnectInfo { + peer_addr: self.peer_addr().ok().map(Arc::new), + peer_cred: self.peer_cred().ok(), + } + } +}