Skip to content

Commit

Permalink
chore(server): Refactor TcpIncoming (hyperium#2052)
Browse files Browse the repository at this point in the history
  • Loading branch information
tottoto authored Nov 26, 2024
1 parent 3c0a00d commit bdccf58
Show file tree
Hide file tree
Showing 13 changed files with 76 additions and 70 deletions.
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/client_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async fn connect_supports_standard_tower_layers() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async fn getting_connect_info() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async fn connect_returns_err_via_call_after_connected() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -85,7 +85,7 @@ async fn connect_lazy_reconnects_after_first_failure() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async fn setting_extension_from_interceptor() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -90,7 +90,7 @@ async fn setting_extension_from_tower() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/tests/http2_keep_alive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async fn http2_keepalive_does_not_cause_panics() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -52,7 +52,7 @@ async fn http2_keepalive_does_not_cause_panics_on_client_side() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/tests/http2_max_header_list_size.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ async fn test_http_max_header_list_size_and_long_errors() {
let addr = format!("http://{}", listener.local_addr().unwrap());

let jh = tokio::spawn(async move {
let (nodelay, keepalive) = (true, Some(Duration::from_secs(1)));
let listener =
tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive)
.unwrap();
let (nodelay, keepalive) = (Some(true), Some(Duration::from_secs(1)));
let listener = tonic::transport::server::TcpIncoming::from(listener)
.with_nodelay(nodelay)
.with_keepalive(keepalive);
Server::builder()
.http2_max_pending_accept_reset_streams(Some(0))
.add_service(svc)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/interceptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn interceptor_retrieves_grpc_method() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

// Start the server now, second call should succeed
let jh = tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/origin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async fn writes_origin_header() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/routes_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async fn multiple_service_using_routes_builder() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/tests/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn status_with_details() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -94,7 +94,7 @@ async fn status_with_metadata() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -165,7 +165,7 @@ async fn status_from_server_stream() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

tokio::spawn(async move {
Server::builder()
Expand Down Expand Up @@ -235,7 +235,7 @@ async fn message_and_then_status_from_server_stream() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

tokio::spawn(async move {
Server::builder()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/tests/user_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn writes_user_agent_header() {

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
Expand Down
94 changes: 48 additions & 46 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
net::{SocketAddr, TcpListener as StdTcpListener},
pin::Pin,
task::{ready, Context, Poll},
task::{Context, Poll},
time::Duration,
};

use socket2::TcpKeepalive;
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::{wrappers::TcpListenerStream, Stream};
use tracing::warn;
Expand All @@ -16,13 +17,13 @@ use tracing::warn;
#[derive(Debug)]
pub struct TcpIncoming {
inner: TcpListenerStream,
nodelay: bool,
keepalive: Option<Duration>,
nodelay: Option<bool>,
keepalive: Option<TcpKeepalive>,
}

impl TcpIncoming {
/// Creates an instance by binding (opening) the specified socket address
/// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
/// Creates an instance by binding (opening) the specified socket address.
///
/// Returns a TcpIncoming if the socket address was successfully bound.
///
/// # Examples
Expand All @@ -42,7 +43,7 @@ impl TcpIncoming {
/// let mut port = 1322;
/// let tinc = loop {
/// let addr = format!("127.0.0.1:{}", port).parse().unwrap();
/// match TcpIncoming::new(addr, true, None) {
/// match TcpIncoming::bind(addr) {
/// Ok(t) => break t,
/// Err(_) => port += 1
/// }
Expand All @@ -52,64 +53,65 @@ impl TcpIncoming {
/// .serve_with_incoming(tinc);
/// # Ok(())
/// # }
pub fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
let std_listener = StdTcpListener::bind(addr)?;
std_listener.set_nonblocking(true)?;

let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
Ok(Self {
inner,
nodelay,
keepalive,
})
Ok(TcpListener::from_std(std_listener)?.into())
}

/// Sets the `TCP_NODELAY` option on the accepted connection.
pub fn with_nodelay(self, nodelay: Option<bool>) -> Self {
Self { nodelay, ..self }
}

/// Sets the `TCP_KEEPALIVE` option on the accepted connection.
pub fn with_keepalive(self, keepalive: Option<Duration>) -> Self {
let keepalive = keepalive.map(|t| TcpKeepalive::new().with_time(t));
Self { keepalive, ..self }
}
}

/// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
pub fn from_listener(
listener: TcpListener,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::BoxError> {
Ok(Self {
impl From<TcpListener> for TcpIncoming {
fn from(listener: TcpListener) -> Self {
Self {
inner: TcpListenerStream::new(listener),
nodelay,
keepalive,
})
nodelay: None,
keepalive: None,
}
}
}

impl Stream for TcpIncoming {
type Item = Result<TcpStream, std::io::Error>;
type Item = std::io::Result<TcpStream>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(stream)) => {
set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
Some(Ok(stream)).into()
}
other => Poll::Ready(other),
let polled = Pin::new(&mut self.inner).poll_next(cx);

if let Poll::Ready(Some(Ok(stream))) = &polled {
set_accepted_socket_options(stream, self.nodelay, &self.keepalive);
}

polled
}
}

// Consistent with hyper-0.14, this function does not return an error.
fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
if nodelay {
if let Err(e) = stream.set_nodelay(true) {
warn!("error trying to set TCP nodelay: {}", e);
fn set_accepted_socket_options(
stream: &TcpStream,
nodelay: Option<bool>,
keepalive: &Option<TcpKeepalive>,
) {
if let Some(nodelay) = nodelay {
if let Err(e) = stream.set_nodelay(nodelay) {
warn!("error trying to set TCP_NODELAY: {e}");
}
}

if let Some(timeout) = keepalive {
if let Some(keepalive) = keepalive {
let sock_ref = socket2::SockRef::from(&stream);
let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);

if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
warn!("error trying to set TCP keepalive: {}", e);
if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) {
warn!("error trying to set TCP_KEEPALIVE: {e}");
}
}
}
Expand All @@ -121,9 +123,9 @@ mod tests {
async fn one_tcpincoming_at_a_time() {
let addr = "127.0.0.1:1322".parse().unwrap();
{
let _t1 = TcpIncoming::new(addr, true, None).unwrap();
let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
let _t1 = TcpIncoming::bind(addr).unwrap();
let _t2 = TcpIncoming::bind(addr).unwrap_err();
}
let _t3 = TcpIncoming::new(addr, true, None).unwrap();
let _t3 = TcpIncoming::bind(addr).unwrap();
}
}
12 changes: 8 additions & 4 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::bind(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
self.routes.prepare(),
Expand Down Expand Up @@ -809,8 +811,10 @@ impl<L> Router<L> {
ResBody: http_body::Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError>,
{
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(super::Error::from_source)?;
let incoming = TcpIncoming::bind(addr)
.map_err(super::Error::from_source)?
.with_nodelay(Some(self.server.tcp_nodelay))
.with_keepalive(self.server.tcp_keepalive);
self.server
.serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
.await
Expand Down

0 comments on commit bdccf58

Please sign in to comment.