Skip to content

Commit

Permalink
feat(transport): add unix socket support in server (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
agreen17 authored Feb 14, 2022
1 parent d6c0fc1 commit dee2ab5
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 79 deletions.
85 changes: 8 additions & 77 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#![cfg_attr(not(unix), allow(unused_imports))]

use futures::TryFutureExt;
use std::path::Path;
#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
#[cfg(unix)]
use tonic::transport::server::UdsConnectInfo;
use tonic::{transport::Server, Request, Response, Status};

pub mod hello_world {
Expand All @@ -26,7 +29,7 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);
}

Expand All @@ -46,89 +49,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let greeter = MyGreeter::default();

let incoming = {
let uds = UnixListener::bind(path)?;

async_stream::stream! {
loop {
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;

yield item;
}
}
};
let uds = UnixListener::bind(path)?;
let uds_stream = UnixListenerStream::new(uds);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(incoming)
.serve_with_incoming(uds_stream)
.await?;

Ok(())
}

#[cfg(unix)]
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::server::Connected;

#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}

#[cfg(not(unix))]
fn main() {
panic!("The `uds` example only works on unix systems!");
Expand Down
74 changes: 74 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,77 @@ async fn getting_connect_info() {

jh.await.unwrap();
}

#[cfg(unix)]
pub mod unix {
use std::convert::TryFrom as _;

use futures_util::FutureExt;
use tokio::{
net::{UnixListener, UnixStream},
sync::oneshot,
};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Endpoint, Server, Uri},
Request, Response, Status,
};
use tower::service_fn;

use integration_tests::pb::{test_client, test_server, Input, Output};

struct Svc {}

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let conn_info = req.extensions().get::<UdsConnectInfo>().unwrap();

// Client-side unix sockets are unnamed.
assert!(req.remote_addr().is_none());
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());

Ok(Response::new(Output {}))
}
}

#[tokio::test]
async fn getting_connect_info() {
let mut unix_socket_path = std::env::temp_dir();
unix_socket_path.push("uds-integration-test");

let uds = UnixListener::bind(&unix_socket_path).unwrap();
let uds_stream = UnixListenerStream::new(uds);

let service = test_server::TestServer::new(Svc {});
let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(service)
.serve_with_incoming_shutdown(uds_stream, rx.map(drop))
.await
.unwrap();
});

// Take a copy before moving into the `service_fn` closure so that the closure
// can implement `FnMut`.
let path = unix_socket_path.clone();
let channel = Endpoint::try_from("http://[::]:50051")
.unwrap()
.connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone())))
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();

std::fs::remove_file(unix_socket_path).unwrap();
}
}
4 changes: 2 additions & 2 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ impl<T> Request<T> {
/// Get the remote address of this connection.
///
/// This will return `None` if the `IO` type used
/// does not implement `Connected`. This currently,
/// only works on the server side.
/// does not implement `Connected` or when using a unix domain socket.
/// This currently only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
#[cfg(feature = "transport")]
{
Expand Down
5 changes: 5 additions & 0 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ mod recover_error;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
#[cfg(unix)]
mod unix;

pub use conn::{Connected, TcpConnectInfo};
#[cfg(feature = "tls")]
Expand All @@ -17,6 +19,9 @@ pub use conn::TlsConnectInfo;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

#[cfg(unix)]
pub use unix::UdsConnectInfo;

use incoming::TcpIncoming;

#[cfg(feature = "tls")]
Expand Down
31 changes: 31 additions & 0 deletions tonic/src/transport/server/unix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use super::Connected;
use std::sync::Arc;

/// Connection info for Unix domain socket streams.
///
/// This type will be accessible through [request extensions][ext] if you're using
/// a unix stream.
///
/// See [Connected] for more details.
///
/// [ext]: crate::Request::extensions
/// [Connected]: crate::transport::server::Connected
#[cfg_attr(docsrs, doc(cfg(unix)))]
#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
/// Peer address. This will be "unnamed" for client unix sockets.
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
/// Process credentials for the unix socket.
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl Connected for tokio::net::UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.peer_addr().ok().map(Arc::new),
peer_cred: self.peer_cred().ok(),
}
}
}

0 comments on commit dee2ab5

Please sign in to comment.