
Commit 1b7e3ed

add support for tower's load-shed layer
Refs: #1616
1 parent: 13b9643

3 files changed: +43 -2 lines

tonic/Cargo.toml (+2 -2)
@@ -39,12 +39,12 @@ server = [
     "dep:socket2",
     "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
     "tokio-stream/net",
-    "dep:tower", "tower?/util", "tower?/limit",
+    "dep:tower", "tower?/util", "tower?/limit", "tower?/load-shed",
 ]
 channel = [
     "dep:hyper", "hyper?/client",
     "dep:hyper-util", "hyper-util?/client-legacy",
-    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
+    "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/load-shed", "tower?/util",
     "dep:tokio", "tokio?/time",
     "dep:hyper-timeout",
 ]
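Both the server and channel feature lists now pull in tower's optional load-shed middleware. As a rough sketch of what that feature provides (illustration only, not part of this commit; the helper name is made up), wrapping any tower Service in LoadShedLayer makes it answer calls immediately with an Overloaded error while the inner service is not ready, instead of waiting:

// Illustration only (helper name is made up): wraps an arbitrary tower
// Service in the load-shed middleware enabled by the "load-shed" feature.
// The wrapped service fails calls with tower::load_shed::error::Overloaded
// whenever the inner service reports it is not ready.
use tower::layer::Layer;
use tower::load_shed::{LoadShed, LoadShedLayer};

fn shed_overload<S>(inner: S) -> LoadShed<S> {
    LoadShedLayer::new().layer(inner)
}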

tonic/src/status.rs (+12)
@@ -348,6 +348,18 @@ impl Status {
             Err(err) => err,
         };

+        // If the load shed middleware is enabled, respond to
+        // service overloaded with an appropriate grpc status.
+        #[cfg(feature = "server")]
+        let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
+            Ok(_) => {
+                return Ok(Status::resource_exhausted(
+                    "Too many active requests for the connection",
+                ));
+            }
+            Err(err) => err,
+        };
+
         if let Some(mut status) = find_status_in_source_chain(&*err) {
             status.source = Some(err.into());
             return Ok(status);
tonic/src/transport/server/mod.rs (+29)
@@ -66,6 +66,7 @@ use tower::{
     layer::util::{Identity, Stack},
     layer::Layer,
     limit::concurrency::ConcurrencyLimitLayer,
+    load_shed::LoadShedLayer,
     util::BoxCloneService,
     Service, ServiceBuilder, ServiceExt,
 };
@@ -87,6 +88,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
 pub struct Server<L = Identity> {
     trace_interceptor: Option<TraceInterceptor>,
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     #[cfg(feature = "_tls-any")]
     tls: Option<TlsAcceptor>,
@@ -111,6 +113,7 @@ impl Default for Server<Identity> {
         Self {
             trace_interceptor: None,
             concurrency_limit: None,
+            load_shed: false,
             timeout: None,
             #[cfg(feature = "_tls-any")]
             tls: None,
@@ -179,6 +182,27 @@ impl<L> Server<L> {
         }
     }

+    /// Enable or disable load shedding. The default is disabled.
+    ///
+    /// When load shedding is enabled, if the service responds with not ready
+    /// the request will immediately be rejected with a
+    /// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted) error.
+    /// The default is to buffer requests. This is especially useful in combination with
+    /// setting a concurrency limit per connection.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use tonic::transport::Server;
+    /// # use tower_service::Service;
+    /// # let builder = Server::builder();
+    /// builder.load_shed(true);
+    /// ```
+    #[must_use]
+    pub fn load_shed(self, load_shed: bool) -> Self {
+        Server { load_shed, ..self }
+    }
+
     /// Set a timeout on for all request handlers.
     ///
     /// # Example
@@ -514,6 +538,7 @@ impl<L> Server<L> {
             service_builder: self.service_builder.layer(new_layer),
             trace_interceptor: self.trace_interceptor,
             concurrency_limit: self.concurrency_limit,
+            load_shed: self.load_shed,
             timeout: self.timeout,
             #[cfg(feature = "_tls-any")]
             tls: self.tls,
@@ -643,6 +668,7 @@ impl<L> Server<L> {
     {
         let trace_interceptor = self.trace_interceptor.clone();
         let concurrency_limit = self.concurrency_limit;
+        let load_shed = self.load_shed;
         let init_connection_window_size = self.init_connection_window_size;
         let init_stream_window_size = self.init_stream_window_size;
         let max_concurrent_streams = self.max_concurrent_streams;
@@ -667,6 +693,7 @@ impl<L> Server<L> {
         let mut svc = MakeSvc {
             inner: svc,
             concurrency_limit,
+            load_shed,
             timeout,
             trace_interceptor,
             _io: PhantomData,
@@ -1051,6 +1078,7 @@ impl<S> fmt::Debug for Svc<S> {
 #[derive(Clone)]
 struct MakeSvc<S, IO> {
     concurrency_limit: Option<usize>,
+    load_shed: bool,
     timeout: Option<Duration>,
     inner: S,
     trace_interceptor: Option<TraceInterceptor>,
@@ -1084,6 +1112,7 @@ where

         let svc = ServiceBuilder::new()
             .layer(RecoverErrorLayer::new())
+            .option_layer(self.load_shed.then_some(LoadShedLayer::new()))
             .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
             .layer_fn(|s| GrpcTimeout::new(s, timeout))
             .service(svc);
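Taken together, the new builder option is threaded from Server down to the per-connection MakeSvc, which inserts LoadShedLayer ahead of the concurrency limit in the service stack. A short usage sketch of the resulting API (the limit value is illustrative, not from the commit): with load_shed(true), requests that find the connection at its concurrency limit are rejected with RESOURCE_EXHAUSTED rather than buffered.

// Usage sketch (the limit value is illustrative): combine the new
// load_shed(true) with a per-connection concurrency limit so that excess
// requests are rejected immediately instead of queueing.
use tonic::transport::Server;

fn main() {
    let _builder = Server::builder()
        .concurrency_limit_per_connection(32)
        .load_shed(true);
    // Services would then be added with .add_service(...) and the server
    // started with .serve(addr) as usual.
}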
