diff --git a/tonic/src/metadata/map.rs b/tonic/src/metadata/map.rs index 8ddccf194..cbef3e101 100644 --- a/tonic/src/metadata/map.rs +++ b/tonic/src/metadata/map.rs @@ -194,7 +194,6 @@ pub struct OccupiedEntry<'a, VE: ValueEncoding> { phantom: PhantomData, } -#[cfg(feature = "transport")] pub(crate) const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; // ===== impl MetadataMap ===== diff --git a/tonic/src/metadata/mod.rs b/tonic/src/metadata/mod.rs index 50bfb49e4..4e796748f 100644 --- a/tonic/src/metadata/mod.rs +++ b/tonic/src/metadata/mod.rs @@ -29,7 +29,6 @@ pub use self::value::AsciiMetadataValue; pub use self::value::BinaryMetadataValue; pub use self::value::MetadataValue; -#[cfg(feature = "transport")] pub(crate) use self::map::GRPC_TIMEOUT_HEADER; /// The metadata::errors module contains types for errors that can occur diff --git a/tonic/src/request.rs b/tonic/src/request.rs index f2f047ffc..7d8f80260 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,11 +1,11 @@ -use crate::metadata::MetadataMap; +use crate::metadata::{MetadataMap, MetadataValue}; #[cfg(feature = "transport")] use crate::transport::Certificate; use futures_core::Stream; use http::Extensions; -use std::net::SocketAddr; #[cfg(feature = "transport")] use std::sync::Arc; +use std::{net::SocketAddr, time::Duration}; /// A gRPC request and metadata from an RPC call. #[derive(Debug)] @@ -221,6 +221,39 @@ impl Request { pub(crate) fn get(&self) -> Option<&I> { self.extensions.get::() } + + /// Set the max duration the request is allowed to take. + /// + /// Requires the server to support the `grpc-timeout` metadata, which Tonic does. + /// + /// The duration will be formatted according to [the spec] and use the most precise unit + /// possible. + /// + /// Example: + /// + /// ```rust + /// use std::time::Duration; + /// use tonic::Request; + /// + /// let mut request = Request::new(()); + /// + /// request.set_timeout(Duration::from_secs(30)); + /// + /// let value = request.metadata().get("grpc-timeout").unwrap(); + /// + /// assert_eq!( + /// value, + /// // equivalent to 30 seconds + /// "30000000u" + /// ); + /// ``` + /// + /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + pub fn set_timeout(&mut self, deadline: Duration) { + let value = MetadataValue::from_str(&duration_to_grpc_timeout(deadline)).unwrap(); + self.metadata_mut() + .insert(crate::metadata::GRPC_TIMEOUT_HEADER, value); + } } impl IntoRequest for T { @@ -265,6 +298,40 @@ mod sealed { pub trait Sealed {} } +fn duration_to_grpc_timeout(duration: Duration) -> String { + fn try_format>( + duration: Duration, + unit: char, + convert: impl FnOnce(Duration) -> T, + ) -> Option { + // The gRPC spec specifies that the timeout most be at most 8 digits. So this is the largest a + // value can be before we need to use a bigger unit. + let max_size: u128 = 99_999_999; // exactly 8 digits + + let value = convert(duration).into(); + if value > max_size { + None + } else { + Some(format!("{}{}", value, unit)) + } + } + + // pick the most precise unit that is less than or equal to 8 digits as per the gRPC spec + try_format(duration, 'n', |d| d.as_nanos()) + .or_else(|| try_format(duration, 'u', |d| d.as_micros())) + .or_else(|| try_format(duration, 'm', |d| d.as_millis())) + .or_else(|| try_format(duration, 'S', |d| d.as_secs())) + .or_else(|| try_format(duration, 'M', |d| d.as_secs() / 60)) + .or_else(|| { + try_format(duration, 'H', |d| { + let minutes = d.as_secs() / 60; + minutes / 60 + }) + }) + // duration has to be more than 11_415 years for this to happen + .expect("duration is unrealistically large") +} + #[cfg(test)] mod tests { use super::*; @@ -283,4 +350,25 @@ mod tests { let http_request = r.into_http(Uri::default()); assert!(http_request.headers().is_empty()); } + + #[test] + fn duration_to_grpc_timeout_less_than_second() { + let timeout = Duration::from_millis(500); + let value = duration_to_grpc_timeout(timeout); + assert_eq!(value, format!("{}u", timeout.as_micros())); + } + + #[test] + fn duration_to_grpc_timeout_more_than_second() { + let timeout = Duration::from_secs(30); + let value = duration_to_grpc_timeout(timeout); + assert_eq!(value, format!("{}u", timeout.as_micros())); + } + + #[test] + fn duration_to_grpc_timeout_a_very_long_time() { + let one_hour = Duration::from_secs(60 * 60); + let value = duration_to_grpc_timeout(one_hour); + assert_eq!(value, format!("{}m", one_hour.as_millis())); + } }