Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions tests/integration_tests/tests/connection.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use http::{header::HeaderName, HeaderValue};
use integration_tests::pb::{test_client::TestClient, test_server, Input, Output};
use std::sync::{Arc, Mutex};
use std::time::Duration;
Expand All @@ -6,6 +7,7 @@ use tonic::{
transport::{server::TcpIncoming, Endpoint, Server},
Code, Request, Response, Status,
};
use tower_http::set_header::SetRequestHeaderLayer;

struct Svc(Arc<Mutex<Option<oneshot::Sender<()>>>>);

Expand Down Expand Up @@ -69,6 +71,63 @@ async fn connect_returns_err_via_call_after_connected() {
jh.await.unwrap();
}

#[tokio::test]
async fn endpoint_layer_stacks_and_applies_to_requests() {
struct HeaderSvc;

#[tonic::async_trait]
impl test_server::Test for HeaderSvc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
// Both layers must have been applied for this to succeed, proving the
// layers stack rather than the second `layer` call replacing the first.
match (
req.metadata().get("x-first"),
req.metadata().get("x-second"),
) {
(Some(_), Some(_)) => Ok(Response::new(Output {})),
_ => Err(Status::internal("a header set by a layer is missing")),
}
}
}

let (tx, rx) = oneshot::channel();
let svc = test_server::TestServer::new(HeaderSvc);

let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let incoming = TcpIncoming::from(listener).with_nodelay(Some(true));

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_incoming_shutdown(incoming, async { drop(rx.await) })
.await
.unwrap();
});

// The layers are configured directly on the `Endpoint`, so no type
// parameters leak onto the resulting `Channel`.
let channel = Endpoint::from_shared(format!("http://{addr}"))
.unwrap()
.layer(SetRequestHeaderLayer::overriding(
HeaderName::from_static("x-first"),
HeaderValue::from_static("first"),
))
.layer(SetRequestHeaderLayer::overriding(
HeaderName::from_static("x-second"),
HeaderValue::from_static("second"),
))
.connect_lazy();

let mut client = TestClient::new(channel);

tokio::time::sleep(Duration::from_millis(100)).await;
client.unary_call(Request::new(Input {})).await.unwrap();

tx.send(()).unwrap();
jh.await.unwrap();
}

#[tokio::test]
async fn connect_lazy_reconnects_after_first_failure() {
let (tx, rx) = oneshot::channel();
Expand Down
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ allowed_external_types = [
"tower_layer::Layer",
"tower_layer::stack::Stack",
"tower_layer::identity::Identity",
"tower::util::boxed::sync::BoxService",
]

[[bench]]
Expand Down
43 changes: 42 additions & 1 deletion tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,35 @@
use super::service::TlsConnector;
use super::service::{self, Executor, SharedExec};
use super::uds_connector::UdsConnector;
use crate::body::Body;
use crate::transport::Error;
#[cfg(feature = "_tls-any")]
use crate::transport::error;
use bytes::Bytes;
use http::{HeaderValue, uri::Uri};
use http::{HeaderValue, Request, Response, uri::Uri};
use hyper::rt;
use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(feature = "_tls-any")]
use std::sync::Arc;
use std::{fmt, future::Future, net::IpAddr, pin::Pin, str, str::FromStr, time::Duration};
#[cfg(feature = "_tls-any")]
use tokio_rustls::rustls::client::danger::ServerCertVerifier;
use tower::layer::Layer;
use tower::layer::util::Stack;
use tower::util::{BoxLayer, BoxService};
use tower_service::Service;

/// A boxed [`Layer`] applied to the [`Connection`](super::service::Connection) service.
///
/// The layer wraps the boxed connection service, allowing arbitrary `tower`
/// middleware to be added to an [`Endpoint`] without leaking type parameters.
pub(crate) type BoxedLayer = BoxLayer<
BoxService<Request<Body>, Response<Body>, crate::BoxError>,
Request<Body>,
Response<Body>,
crate::BoxError,
>;

#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) enum EndpointType {
Uri(Uri),
Expand Down Expand Up @@ -56,6 +71,7 @@
pub(crate) http2_adaptive_window: Option<bool>,
pub(crate) local_address: Option<IpAddr>,
pub(crate) executor: SharedExec,
pub(crate) layer: Option<BoxedLayer>,
}

impl Endpoint {
Expand All @@ -69,7 +85,7 @@
{
let me = dst.try_into().map_err(|e| Error::from_source(e.into()))?;
#[cfg(feature = "_tls-any")]
if let EndpointType::Uri(uri) = &me.uri {

Check warning on line 88 in tonic/src/transport/channel/endpoint.rs

View workflow job for this annotation

GitHub Actions / clippy

this `if` statement can be collapsed
if me.tls.is_none() && uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
return me.tls_config(ClientTlsConfig::new().with_enabled_roots());
}
Expand Down Expand Up @@ -105,6 +121,7 @@
http2_adaptive_window: None,
executor: SharedExec::tokio(),
local_address: None,
layer: None,
}
}

Expand Down Expand Up @@ -136,6 +153,7 @@
http2_adaptive_window: None,
executor: SharedExec::tokio(),
local_address: None,
layer: None,
}
}

Expand Down Expand Up @@ -519,6 +537,29 @@
self
}

/// Add a [`tower::Layer`] to the stack of middleware applied to each
/// connection.
///
/// Calling this method multiple times stacks the layers, with the most
/// recently added layer wrapping the previously added ones.
pub fn layer<L>(mut self, layer: L) -> Self
where
L: Layer<BoxService<Request<Body>, Response<Body>, crate::BoxError>>
+ Send
+ Sync
+ 'static,
L::Service: Service<Request<Body>, Response = Response<Body>, Error = crate::BoxError>
+ Send
+ 'static,
<L::Service as Service<Request<Body>>>::Future: Send + 'static,
{
self.layer = Some(match self.layer.take() {
Some(existing) => BoxLayer::new(Stack::new(existing, layer)),
None => BoxLayer::new(layer),
});
self
}

pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
service::Connector::new(
c,
Expand Down
10 changes: 7 additions & 3 deletions tonic/src/transport/channel/service/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ impl Connection {

let conn = Reconnect::new(make_service, endpoint.uri().clone(), is_lazy);

Self {
inner: BoxService::new(stack.layer(conn)),
}
let inner = BoxService::new(stack.layer(conn));
let inner = match &endpoint.layer {
Some(layer) => layer.layer(inner),
None => inner,
};

Self { inner }
}

pub(crate) async fn connect<C>(
Expand Down
Loading