Skip to content

Commit e043e5b

Browse files
committed
add support for tower load-shed
1 parent fc940ce commit e043e5b

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

tonic/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ server = [
4848
channel = [
4949
"dep:hyper", "hyper?/client",
5050
"dep:hyper-util", "hyper-util?/client-legacy",
51-
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/util",
51+
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/limit", "tower?/load-shed", "tower?/util",
5252
"dep:tokio", "tokio?/time",
5353
"dep:hyper-timeout",
5454
]

tonic/src/status.rs

+11
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,17 @@ impl Status {
348348
Err(err) => err,
349349
};
350350

351+
// If the load shed middleware is enabled, respond to
352+
// service overloaded with an appropriate grpc status.
353+
let err = match err.downcast::<tower::load_shed::error::Overloaded>() {
354+
Ok(_) => {
355+
return Ok(Status::resource_exhausted(
356+
"Too many active requests for the connection",
357+
));
358+
}
359+
Err(err) => err,
360+
};
361+
351362
if let Some(mut status) = find_status_in_source_chain(&*err) {
352363
status.source = Some(err.into());
353364
return Ok(status);

tonic/src/transport/server/mod.rs

+29
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ use tower::{
6868
layer::util::{Identity, Stack},
6969
layer::Layer,
7070
limit::concurrency::ConcurrencyLimitLayer,
71+
load_shed::LoadShedLayer,
7172
util::BoxCloneService,
7273
Service, ServiceBuilder, ServiceExt,
7374
};
@@ -89,6 +90,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20;
8990
pub struct Server<L = Identity> {
9091
trace_interceptor: Option<TraceInterceptor>,
9192
concurrency_limit: Option<usize>,
93+
load_shed: bool,
9294
timeout: Option<Duration>,
9395
#[cfg(feature = "_tls-any")]
9496
tls: Option<TlsAcceptor>,
@@ -113,6 +115,7 @@ impl Default for Server<Identity> {
113115
Self {
114116
trace_interceptor: None,
115117
concurrency_limit: None,
118+
load_shed: false,
116119
timeout: None,
117120
#[cfg(feature = "_tls-any")]
118121
tls: None,
@@ -181,6 +184,27 @@ impl<L> Server<L> {
181184
}
182185
}
183186

187+
/// Enable or disable load shedding. The default is disabled.
188+
///
189+
/// When load shedding is enabled, if the service responds with not ready
190+
/// the request will immediately be rejected with a
191+
/// [`resource_exhausted`](https://docs.rs/tonic/latest/tonic/struct.Status.html#method.resource_exhausted) error.
192+
/// The default is to buffer requests. This is especially useful in combination with
193+
/// setting a concurrency limit per connection.
194+
///
195+
/// # Example
196+
///
197+
/// ```
198+
/// # use tonic::transport::Server;
199+
/// # use tower_service::Service;
200+
/// # let builder = Server::builder();
201+
/// builder.load_shed(true);
202+
/// ```
203+
#[must_use]
204+
pub fn load_shed(self, load_shed: bool) -> Self {
205+
Server { load_shed, ..self }
206+
}
207+
184208
/// Set a timeout on for all request handlers.
185209
///
186210
/// # Example
@@ -516,6 +540,7 @@ impl<L> Server<L> {
516540
service_builder: self.service_builder.layer(new_layer),
517541
trace_interceptor: self.trace_interceptor,
518542
concurrency_limit: self.concurrency_limit,
543+
load_shed: self.load_shed,
519544
timeout: self.timeout,
520545
#[cfg(feature = "_tls-any")]
521546
tls: self.tls,
@@ -645,6 +670,7 @@ impl<L> Server<L> {
645670
{
646671
let trace_interceptor = self.trace_interceptor.clone();
647672
let concurrency_limit = self.concurrency_limit;
673+
let load_shed = self.load_shed;
648674
let init_connection_window_size = self.init_connection_window_size;
649675
let init_stream_window_size = self.init_stream_window_size;
650676
let max_concurrent_streams = self.max_concurrent_streams;
@@ -671,6 +697,7 @@ impl<L> Server<L> {
671697
let mut svc = MakeSvc {
672698
inner: svc,
673699
concurrency_limit,
700+
load_shed,
674701
timeout,
675702
trace_interceptor,
676703
_io: PhantomData,
@@ -1056,6 +1083,7 @@ impl<S> fmt::Debug for Svc<S> {
10561083
#[derive(Clone)]
10571084
struct MakeSvc<S, IO> {
10581085
concurrency_limit: Option<usize>,
1086+
load_shed: bool,
10591087
timeout: Option<Duration>,
10601088
inner: S,
10611089
trace_interceptor: Option<TraceInterceptor>,
@@ -1089,6 +1117,7 @@ where
10891117

10901118
let svc = ServiceBuilder::new()
10911119
.layer(RecoverErrorLayer::new())
1120+
.option_layer(self.load_shed.then_some(LoadShedLayer::new()))
10921121
.option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
10931122
.layer_fn(|s| GrpcTimeout::new(s, timeout))
10941123
.service(svc);

0 commit comments

Comments
 (0)