Skip to content

Commit

Permalink
feat(transport): Expose more granular control of TLS configuration (#48)
Browse files Browse the repository at this point in the history
This commit reworks TLS configuration of both servers and endpoints in
order to provide a more flexible API. We now add options to configure
the selected TLS library using the appropriate 'native' configuration
structures, as well as retaining the existing simplier interface which
is compatible with both.

The new API can also be easily extended to support simple interfaces for
configuring mTLS and a range of other options without creating sprawl
in the builders for `Server` and `Endpoint`.
  • Loading branch information
jen20 authored and LucioFranco committed Oct 8, 2019
1 parent 4628ff0 commit 8db3961
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 131 deletions.
10 changes: 7 additions & 3 deletions tonic-examples/src/tls/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@ pub mod pb {
}

use pb::{client::EchoClient, EchoRequest};
use tonic::transport::{Certificate, Channel};
use tonic::transport::{Certificate, Channel, ClientTlsConfig};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let pem = tokio::fs::read("tonic-examples/data/tls/ca.pem").await?;
let ca = Certificate::from_pem(pem);

let tls = ClientTlsConfig::with_rustls()
.ca_certificate(ca)
.domain_name("example.com")
.clone();

let channel = Channel::from_static("http://[::1]:50051")
.rustls_tls(ca, Some("example.com".into()))
.tls_config(&tls)
.channel();

let mut client = EchoClient::new(channel);

let request = tonic::Request::new(EchoRequest {
message: "hello".into(),
});
Expand Down
4 changes: 2 additions & 2 deletions tonic-examples/src/tls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod pb {
use pb::{EchoRequest, EchoResponse};
use std::collections::VecDeque;
use tonic::{
transport::{Identity, Server},
transport::{Identity, Server, ServerTlsConfig},
Request, Response, Status, Streaming,
};

Expand Down Expand Up @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let server = EchoServer::default();

Server::builder()
.rustls_tls(identity)
.tls_config(ServerTlsConfig::with_rustls().identity(identity))
.clone()
.serve(addr, pb::server::EchoServer::new(server))
.await?;
Expand Down
9 changes: 7 additions & 2 deletions tonic-interop/src/bin/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;
use structopt::{clap::arg_enum, StructOpt};
use tonic::transport::{Certificate, Endpoint};
use tonic::transport::{Certificate, ClientTlsConfig, Endpoint};
use tonic_interop::client;

#[derive(StructOpt)]
Expand Down Expand Up @@ -33,7 +33,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if matches.use_tls {
let pem = tokio::fs::read("tonic-interop/data/ca.pem").await?;
let ca = Certificate::from_pem(pem);
endpoint.openssl_tls(ca, Some("foo.test.google.fr".into()));

endpoint.tls_config(
ClientTlsConfig::with_openssl()
.ca_certificate(ca)
.domain_name("foo.test.google.fr"),
);
}

let channel = endpoint.channel();
Expand Down
4 changes: 2 additions & 2 deletions tonic-interop/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use http::header::HeaderName;
use structopt::StructOpt;
use tonic::body::BoxBody;
use tonic::client::GrpcService;
use tonic::transport::{Identity, Server};
use tonic::transport::{Identity, Server, ServerTlsConfig};
use tonic_interop::{server, MergeTrailers};

#[derive(StructOpt)]
Expand All @@ -26,7 +26,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let key = tokio::fs::read("tonic-interop/data/server1.key").await?;

let identity = Identity::from_pem(cert, key);
builder.openssl_tls(identity);
builder.tls_config(ServerTlsConfig::with_openssl().identity(identity));
}

builder.interceptor_fn(|svc, req| {
Expand Down
171 changes: 112 additions & 59 deletions tonic/src/transport/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::channel::Channel;
#[cfg(feature = "tls")]
use super::{service::TlsConnector, tls::Certificate};
use super::{
service::TlsConnector,
tls::{Certificate, TlsProvider},
};
use bytes::Bytes;
use http::uri::{InvalidUriBytes, Uri};
use std::{
Expand Down Expand Up @@ -122,64 +125,6 @@ impl Endpoint {
self
}

/// Enable TLS and apply the CA as the root certificate.
///
/// Providing an optional domain to override. If `None` is passed to this
/// the TLS implementation will use the `Uri` that was used to create the
/// `Endpoint` builder.
///
/// ```no_run
/// # use tonic::transport::{Certificate, Endpoint};
/// # fn dothing() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut builder = Endpoint::from_static("https://example.com");
/// let ca = std::fs::read_to_string("ca.pem")?;
///
/// let ca = Certificate::from_pem(ca);
///
/// builder.openssl_tls(ca, "example.com".to_string());
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "openssl")]
#[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub fn openssl_tls(&mut self, ca: Certificate, domain: impl Into<Option<String>>) -> &mut Self {
let domain = domain
.into()
.unwrap_or_else(|| self.uri.clone().to_string());
let tls = TlsConnector::new_with_openssl(ca, domain).unwrap();
self.tls = Some(tls);
self
}

/// Enable TLS and apply the CA as the root certificate.
///
/// Providing an optional domain to override. If `None` is passed to this
/// the TLS implementation will use the `Uri` that was used to create the
/// `Endpoint` builder.
///
/// ```no_run
/// # use tonic::transport::{Certificate, Endpoint};
/// # fn dothing() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut builder = Endpoint::from_static("https://example.com");
/// let ca = std::fs::read_to_string("ca.pem")?;
///
/// let ca = Certificate::from_pem(ca);
///
/// builder.rustls_tls(ca, "example.com".to_string());
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "rustls")]
#[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
pub fn rustls_tls(&mut self, ca: Certificate, domain: impl Into<Option<String>>) -> &mut Self {
let domain = domain
.into()
.unwrap_or_else(|| self.uri.clone().to_string());
let tls = TlsConnector::new_with_rustls(ca, domain).unwrap();
self.tls = Some(tls);
self
}

/// Intercept outbound HTTP Request headers;
pub fn intercept_headers<F>(&mut self, f: F) -> &mut Self
where
Expand All @@ -189,6 +134,13 @@ impl Endpoint {
self
}

/// Configures TLS for the endpoint.
#[cfg(feature = "tls")]
pub fn tls_config(&mut self, tls_config: &ClientTlsConfig) -> &mut Self {
self.tls = Some(tls_config.tls_connector(self.uri.clone()).unwrap());
self
}

/// Create a channel from this config.
pub fn channel(&self) -> Channel {
Channel::connect(self.clone())
Expand Down Expand Up @@ -252,3 +204,104 @@ impl fmt::Debug for Endpoint {
f.debug_struct("Endpoint").finish()
}
}

/// Configures TLS settings for endpoints.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub struct ClientTlsConfig {
provider: TlsProvider,
domain: Option<String>,
cert: Option<Certificate>,
#[cfg(feature = "openssl")]
openssl_raw: Option<openssl1::ssl::SslConnector>,
#[cfg(feature = "rustls")]
rustls_raw: Option<tokio_rustls::rustls::ClientConfig>,
}

#[cfg(feature = "tls")]
impl fmt::Debug for ClientTlsConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientTlsConfig")
.field("provider", &self.provider)
.finish()
}
}

#[cfg(feature = "tls")]
impl ClientTlsConfig {
/// Creates a new `ClientTlsConfig` using OpenSSL.
#[cfg(feature = "openssl")]
pub fn with_openssl() -> Self {
Self::new(TlsProvider::OpenSsl)
}

/// Creates a new `ClientTlsConfig` using Rustls.
#[cfg(feature = "rustls")]
pub fn with_rustls() -> Self {
Self::new(TlsProvider::Rustls)
}

fn new(provider: TlsProvider) -> Self {
ClientTlsConfig {
provider,
domain: None,
cert: None,
#[cfg(feature = "openssl")]
openssl_raw: None,
#[cfg(feature = "rustls")]
rustls_raw: None,
}
}

/// Sets the domain name against which to verify the server's TLS certificate.
pub fn domain_name(&mut self, domain_name: impl Into<String>) -> &mut Self {
self.domain = Some(domain_name.into());
self
}

/// Sets the CA Certificate against which to verify the server's TLS certificate.
pub fn ca_certificate(&mut self, ca_certificate: Certificate) -> &mut Self {
self.cert = Some(ca_certificate);
self
}

/// Use options specified by the given `SslConnector` to configure TLS.
///
/// This overrides all other TLS options set via other means.
#[cfg(feature = "openssl")]
pub fn openssl_connector(&mut self, connector: openssl1::ssl::SslConnector) -> &mut Self {
self.openssl_raw = Some(connector);
self
}

/// Use options specified by the given `ClientConfig` to configure TLS.
///
/// This overrides all other TLS options set via other means.
#[cfg(feature = "rustls")]
pub fn rustls_client_config(
&mut self,
config: tokio_rustls::rustls::ClientConfig,
) -> &mut Self {
self.rustls_raw = Some(config);
self
}

fn tls_connector(&self, uri: Uri) -> Result<TlsConnector, crate::Error> {
let domain = match &self.domain {
None => uri.to_string(),
Some(domain) => domain.clone(),
};
match self.provider {
#[cfg(feature = "openssl")]
TlsProvider::OpenSsl => match &self.openssl_raw {
None => TlsConnector::new_with_openssl_cert(self.cert.clone(), domain),
Some(r) => TlsConnector::new_with_openssl_raw(r.clone(), domain),
},
#[cfg(feature = "rustls")]
TlsProvider::Rustls => match &self.rustls_raw {
None => TlsConnector::new_with_rustls_cert(self.cert.clone(), domain),
Some(c) => TlsConnector::new_with_rustls_raw(c.clone(), domain),
},
}
}
}
16 changes: 12 additions & 4 deletions tonic/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! ## Client
//!
//! ```no_run
//! # use tonic::transport::{Channel, Certificate};
//! # use tonic::transport::{Channel, Certificate, ClientTlsConfig};
//! # use std::time::Duration;
//! # use tonic::body::BoxBody;
//! # use tonic::client::GrpcService;;
Expand All @@ -29,7 +29,9 @@
//! let cert = std::fs::read_to_string("ca.pem")?;
//!
//! let mut channel = Channel::from_static("https://example.com")
//! .rustls_tls(Certificate::from_pem(&cert), "example.com".to_string())
//! .tls_config(ClientTlsConfig::with_rustls()
//! .ca_certificate(Certificate::from_pem(&cert))
//! .domain_name("example.com".to_string()))
//! .timeout(Duration::from_secs(5))
//! .rate_limit(5, Duration::from_secs(1))
//! .concurrency_limit(256)
Expand All @@ -43,7 +45,7 @@
//! ## Server
//!
//! ```no_run
//! # use tonic::transport::{Server, Identity};
//! # use tonic::transport::{Server, Identity, ServerTlsConfig};
//! # use tower::{Service, service_fn};
//! # use futures_util::future::{err, ok};
//! # #[cfg(feature = "rustls")]
Expand All @@ -55,7 +57,8 @@
//! let addr = "[::1]:50051".parse()?;
//!
//! Server::builder()
//! .rustls_tls(Identity::from_pem(&cert, &key))
//! .tls_config(ServerTlsConfig::with_rustls()
//! .identity(Identity::from_pem(&cert, &key)))
//! .concurrency_limit_per_connection(256)
//! .interceptor_fn(|svc, req| {
//! println!("Request: {:?}", req);
Expand Down Expand Up @@ -89,4 +92,9 @@ pub use self::server::Server;
pub use self::tls::{Certificate, Identity};
pub use hyper::Body;

#[cfg(feature = "tls")]
pub use self::endpoint::ClientTlsConfig;
#[cfg(feature = "tls")]
pub use self::server::ServerTlsConfig;

pub(crate) use self::error::ErrorKind;
Loading

0 comments on commit 8db3961

Please sign in to comment.