Skip to content

Commit

Permalink
add tls session storage & ticketer
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Jun 1, 2024
1 parent 9f374e5 commit 775ef49
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 29 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "HTTP-to-IC gateway"
edition = "2021"

[dependencies]
ahash = "0.8"
anyhow = "1.0"
arc-swap = "1"
async-scoped = { version = "0.8", features = ["use-tokio"] }
Expand All @@ -15,6 +16,7 @@ axum-server = { version = "0.6", features = ["tls-rustls"] }
backoff = { version = "0.4", features = ["tokio"] }
bytes = "1.5"
candid = "0.10"
chacha20poly1305 = "0.10"
chrono = "0.4"
clap = { version = "4.5", features = ["derive", "string"] }
clap_derive = "4.5"
Expand Down Expand Up @@ -63,6 +65,7 @@ moka = { version = "0.12", features = ["sync", "future"] }
num-bigint = { version = "0.4.5", features = ["serde"] }
ocsp-stapler = "0.3"
once_cell = "1.19"
parse-size = { version = "1.0", features = ["std"] }
prometheus = "0.13"
rand = "0.8"
rasn = "0.15"
Expand Down
18 changes: 17 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ use crate::{
tls::{self, acme},
};

/// Parses a human-readable size string (such as `256MB`) into a number of bytes.
///
/// The parser is configured with the binary unit system, so unit suffixes are
/// interpreted as powers of 1024. Intended to be used as a clap `value_parser`.
///
/// Returns a `parse_size::Error` if the string cannot be parsed as a size.
fn parse_size(s: &str) -> Result<u64, parse_size::Error> {
    // Build the parser configuration first, then apply it to the input
    let binary_units = parse_size::Config::new().with_binary();
    binary_units.parse_size(s)
}

#[derive(Parser)]
#[clap(name = SERVICE_NAME)]
#[clap(author = AUTHOR_NAME)]
Expand Down Expand Up @@ -132,6 +136,18 @@ pub struct HttpServer {
/// How long to wait for the existing connections to finish before shutting down
#[clap(long = "http-server-grace-period", default_value = "10s", value_parser = parse_duration)]
pub grace_period: Duration,

/// Maximum size of cache to store TLS sessions in memory
#[clap(long = "http-server-tls-session-cache-size", default_value = "256MB", value_parser = parse_size)]
pub tls_session_cache_size: u64,

/// Maximum time that a TLS session key can stay in cache without being requested (Time-to-Idle)
#[clap(long = "http-server-tls-session-cache-tti", default_value = "18h", value_parser = parse_duration)]
pub tls_session_cache_tti: Duration,

/// Lifetime of a TLS1.3 ticket. Due to key rotation the actual lifetime can be up to twice this value.
#[clap(long = "http-server-tls-ticket-lifetime", default_value = "9h", value_parser = parse_duration)]
pub tls_ticket_lifetime: Duration,
}

#[derive(Args)]
Expand Down Expand Up @@ -243,7 +259,7 @@ pub struct Acme {

/// Attempt to renew the certificates when less than this duration is left until expiration.
/// This works only with DNS challenge, ALPN currently starts to renew after half of certificate
/// lifetime has passed.
/// lifetime has passed (45d for LetsEncrypt)
#[clap(long = "acme-renew-before", value_parser = parse_duration, default_value = "30d")]
pub acme_renew_before: Duration,

Expand Down
7 changes: 6 additions & 1 deletion src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
let dns_resolver = http::dns::Resolver::new((&cli.dns).into());
let reqwest_client = http::client::new((&cli.http_client).into(), dns_resolver.clone())?;
let http_client = Arc::new(http::ReqwestClient::new(reqwest_client.clone()));
let tls_session_cache = Arc::new(tls::sessions::Storage::new(
cli.http_server.tls_session_cache_size,
cli.http_server.tls_session_cache_tti,
));
let clickhouse = if cli.log.clickhouse.log_clickhouse_url.is_some() {
Some(Arc::new(
log::clickhouse::Clickhouse::new(&cli.log.clickhouse)
Expand Down Expand Up @@ -106,6 +110,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {
http_client.clone(),
storage,
Arc::new(dns_resolver),
tls_session_cache.clone(),
&registry,
)
.await
Expand All @@ -122,7 +127,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> {

// Setup metrics
if let Some(addr) = cli.metrics.listen {
let router = metrics::setup(&registry, &mut tasks);
let router = metrics::setup(&registry, tls_session_cache, &mut tasks);

let srv = Arc::new(http::Server::new(
addr,
Expand Down
46 changes: 43 additions & 3 deletions src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::{
RequestCtx,
},
tasks::{Run, TaskManager},
tls::sessions,
};
use body::CountingBody;

Expand Down Expand Up @@ -67,14 +68,23 @@ impl MetricsCache {
// Periodic task state: holds the Prometheus registry, the gauges that are
// refreshed on each run and the shared cache that the metrics HTTP handler reads.
pub struct MetricsRunner {
// Cache shared with the metrics HTTP handler (created in `setup`)
metrics_cache: Arc<RwLock<MetricsCache>>,
// Registry that the gauges below are registered in and that is gathered for a snapshot
registry: Registry,
// TLS session storage whose `stats()` feed the TLS session gauges
tls_session_cache: Arc<sessions::Storage>,
// Text encoder; presumably used to encode the gathered metric families - confirm in `run`
encoder: TextEncoder,

// Metrics
// Allocated memory in bytes
mem_allocated: IntGauge,
// Resident memory as reported by `stats::resident::read()`
mem_resident: IntGauge,
// Number of TLS sessions in the cache (`stats.entries`)
tls_session_cache_count: IntGauge,
// Size of TLS sessions in the cache (`stats.size`)
tls_session_cache_size: IntGauge,
}

// Snapshots & encodes the metrics for the handler to export
impl MetricsRunner {
pub fn new(metrics_cache: Arc<RwLock<MetricsCache>>, registry: &Registry) -> Self {
pub fn new(
metrics_cache: Arc<RwLock<MetricsCache>>,
registry: &Registry,
tls_session_cache: Arc<sessions::Storage>,
) -> Self {
let mem_allocated = register_int_gauge_with_registry!(
format!("memory_allocated"),
format!("Allocated memory in bytes"),
Expand All @@ -89,12 +99,29 @@ impl MetricsRunner {
)
.unwrap();

let tls_session_cache_count = register_int_gauge_with_registry!(
format!("tls_session_cache_count"),
format!("Number of TLS sessions in the cache"),
registry
)
.unwrap();

let tls_session_cache_size = register_int_gauge_with_registry!(
format!("tls_session_cache_size"),
format!("Size of TLS sessions in the cache"),
registry
)
.unwrap();

Self {
metrics_cache,
registry: registry.clone(),
tls_session_cache,
encoder: TextEncoder::new(),
mem_allocated,
mem_resident,
tls_session_cache_count,
tls_session_cache_size,
}
}
}
Expand All @@ -108,6 +135,11 @@ impl MetricsRunner {
self.mem_resident
.set(stats::resident::read().unwrap() as i64);

// Record TLS session stats
let stats = self.tls_session_cache.stats();
self.tls_session_cache_count.set(stats.entries as i64);
self.tls_session_cache_size.set(stats.size as i64);

// Get a snapshot of metrics
let metric_families = self.registry.gather();

Expand Down Expand Up @@ -158,9 +190,17 @@ async fn handler(State(state): State<Arc<RwLock<MetricsCache>>>) -> impl IntoRes
)
}

pub fn setup(registry: &Registry, tasks: &mut TaskManager) -> Router {
pub fn setup(
registry: &Registry,
tls_session_cache: Arc<sessions::Storage>,
tasks: &mut TaskManager,
) -> Router {
let cache = Arc::new(RwLock::new(MetricsCache::new(METRICS_CACHE_CAPACITY)));
let runner = Arc::new(MetricsRunner::new(cache.clone(), registry));
let runner = Arc::new(MetricsRunner::new(
cache.clone(),
registry,
tls_session_cache,
));
tasks.add("metrics_runner", runner);

Router::new()
Expand Down
19 changes: 8 additions & 11 deletions src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub fn setup_router(
.route("/status", get(proxy::api_proxy))
.with_state(state_api);

let router = Router::new()
let mut router = Router::new()
.nest("/api/v2", router_api)
.fallback(
get(handler::handler)
Expand All @@ -155,15 +155,14 @@ pub fn setup_router(
.layer(common_layers);

// Setup issuer proxy endpoint if we have them configured
let router = if !cli.cert.issuer_urls.is_empty() {
if !cli.cert.issuer_urls.is_empty() {
// Init it early to avoid threading races
lazy_static::initialize(&proxy::REGEX_REG_ID);

// Strip possible path from URLs
let mut urls = cli.cert.issuer_urls.clone();
urls.iter_mut().for_each(|x| x.set_path(""));

let state = Arc::new(proxy::IssuerProxyState::new(http_client, urls));
let state = Arc::new(proxy::IssuerProxyState::new(
http_client,
cli.cert.issuer_urls.clone(),
));
let router_issuer = Router::new()
.route(
"/:id",
Expand All @@ -174,10 +173,8 @@ pub fn setup_router(
.route("/", post(proxy::issuer_proxy))
.with_state(state);

router.nest("/registrations", router_issuer)
} else {
router
};
router = router.nest("/registrations", router_issuer)
}

Ok(router)
}
11 changes: 5 additions & 6 deletions src/routing/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::{
use anyhow::Error;
use axum::{
body::Body,
extract::{Path, Request, State},
extract::{OriginalUri, Path, Request, State},
response::{IntoResponse, Response},
};
use candid::Principal;
Expand Down Expand Up @@ -59,6 +59,7 @@ pub struct ApiProxyState {
// Proxies /api/v2/... endpoints to the IC
pub async fn api_proxy(
State(state): State<Arc<ApiProxyState>>,
OriginalUri(uri): OriginalUri,
principal: Option<Path<String>>,
request: Request,
) -> Result<impl IntoResponse, ErrorCause> {
Expand All @@ -75,7 +76,7 @@ pub async fn api_proxy(

// Append the query URL to the IC url
let url = url
.join(request.uri().path())
.join(uri.path())
.map_err(|e| ErrorCause::MalformedRequest(format!("incorrect URL: {e}")))?;

// Proxy the request
Expand All @@ -97,6 +98,7 @@ pub struct IssuerProxyState {
// Proxies /registrations endpoint to the certificate issuers if they're defined
pub async fn issuer_proxy(
State(state): State<Arc<IssuerProxyState>>,
OriginalUri(uri): OriginalUri,
id: Option<Path<String>>,
request: Request,
) -> Result<impl IntoResponse, ErrorCause> {
Expand All @@ -109,15 +111,12 @@ pub async fn issuer_proxy(
}
}

// Extract path part from the request
let path = request.uri().path();

// Pick next issuer using round-robin & generate request URL for it
// TODO should we do retries here?
let next = state.next.fetch_add(1, Ordering::SeqCst) % state.issuers.len();
let url = state.issuers[next]
.clone()
.join(path)
.join(uri.path())
.map_err(|_| ErrorCause::MalformedRequest("unable to parse path as URL part".into()))?;

let response = proxy(url, request, &state.http_client)
Expand Down
29 changes: 22 additions & 7 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
pub mod acme;
pub mod cert;
pub mod resolver;
pub mod sessions;
pub mod tickets;

use std::{fs, sync::Arc};
use std::{fs, sync::Arc, time::Duration};

use anyhow::{anyhow, Context, Error};
use fqdn::{Fqdn, FQDN};
Expand All @@ -11,11 +13,9 @@ use ocsp_stapler::Stapler;
use prometheus::Registry;
use rustls::{
client::{ClientConfig, ClientSessionMemoryCache, Resumption},
server::{
ResolvesServerCert as ResolvesServerCertRustls, ServerConfig, ServerSessionMemoryCache,
},
server::{ResolvesServerCert as ResolvesServerCertRustls, ServerConfig, StoresServerSessions},
version::{TLS12, TLS13},
RootCertStore,
RootCertStore, TicketSwitcher,
};
use rustls_acme::acme::ACME_TLS_ALPN_NAME;

Expand Down Expand Up @@ -49,14 +49,25 @@ pub fn sni_matches(host: &Fqdn, domains: &[FQDN], wildcard: bool) -> bool {

pub fn prepare_server_config(
resolver: Arc<dyn ResolvesServerCertRustls>,
session_storage: Arc<dyn StoresServerSessions + Send + Sync>,
additional_alpn: Vec<Vec<u8>>,
ticket_lifetime: Duration,
) -> ServerConfig {
let mut cfg = ServerConfig::builder_with_protocol_versions(&[&TLS13, &TLS12])
.with_no_client_auth()
.with_cert_resolver(resolver);

// Create custom session storage with higher limit to allow effective TLS session resumption
cfg.session_storage = ServerSessionMemoryCache::new(131_072);
// Set custom session storage to allow effective TLS session resumption
cfg.session_storage = session_storage;

// Enable ticketer
let ticketer = TicketSwitcher::new(ticket_lifetime.as_secs() as u32, || {
Ok(Box::new(tickets::Ticketer::new()))
})
.unwrap();
cfg.ticketer = Arc::new(ticketer);

// Set ALPN protocols: HTTP/2 and HTTP/1.1 plus any additional ones requested by the caller
cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()];
cfg.alpn_protocols.extend_from_slice(&additional_alpn);

Expand Down Expand Up @@ -134,13 +145,15 @@ async fn setup_acme(
}

// Prepares the stuff needed for serving TLS
#[allow(clippy::too_many_arguments)]
pub async fn setup(
cli: &Cli,
tasks: &mut TaskManager,
domains: Vec<FQDN>,
http_client: Arc<dyn Client>,
storage: Arc<StorageKey>,
dns_resolver: Arc<dyn Resolves>,
tls_session_storage: Arc<dyn StoresServerSessions + Send + Sync>,
registry: &Registry,
) -> Result<ServerConfig, Error> {
let mut providers: Vec<Arc<dyn ProvidesCertificates>> = vec![];
Expand Down Expand Up @@ -195,11 +208,13 @@ pub async fn setup(
// Generate Rustls config
let config = prepare_server_config(
certificate_resolver,
tls_session_storage,
if cli.acme.acme_challenge == Some(Challenge::Alpn) {
vec![ACME_TLS_ALPN_NAME.to_vec()]
} else {
vec![]
},
cli.http_server.tls_ticket_lifetime,
);

Ok(config)
Expand Down
2 changes: 2 additions & 0 deletions src/tls/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub struct AggregatingResolver {
// Implement certificate resolving for Rustls
impl ResolvesServerCertRustls for AggregatingResolver {
fn resolve(&self, ch: ClientHello) -> Option<Arc<CertifiedKey>> {
println!("{:?}", ch.server_name());

// Iterate over our resolvers to find matching cert if any.
self.resolvers
.iter()
Expand Down
Loading

0 comments on commit 775ef49

Please sign in to comment.