From 775ef49d2ee434e8f1ff636b1ef46961e13f5f58 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Sat, 1 Jun 2024 16:32:06 +0200 Subject: [PATCH] add tls session storage & ticketer --- Cargo.toml | 3 ++ src/cli.rs | 18 ++++++- src/core.rs | 7 ++- src/metrics/mod.rs | 46 ++++++++++++++++-- src/routing/mod.rs | 19 ++++---- src/routing/proxy.rs | 11 ++--- src/tls/mod.rs | 29 +++++++++--- src/tls/resolver.rs | 2 + src/tls/sessions.rs | 106 +++++++++++++++++++++++++++++++++++++++++ src/tls/tickets.rs | 110 +++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 322 insertions(+), 29 deletions(-) create mode 100644 src/tls/sessions.rs create mode 100644 src/tls/tickets.rs diff --git a/Cargo.toml b/Cargo.toml index 3959679..c411c4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ description = "HTTP-to-IC gateway" edition = "2021" [dependencies] +ahash = "0.8" anyhow = "1.0" arc-swap = "1" async-scoped = { version = "0.8", features = ["use-tokio"] } @@ -15,6 +16,7 @@ axum-server = { version = "0.6", features = ["tls-rustls"] } backoff = { version = "0.4", features = ["tokio"] } bytes = "1.5" candid = "0.10" +chacha20poly1305 = "0.10" chrono = "0.4" clap = { version = "4.5", features = ["derive", "string"] } clap_derive = "4.5" @@ -63,6 +65,7 @@ moka = { version = "0.12", features = ["sync", "future"] } num-bigint = { version = "0.4.5", features = ["serde"] } ocsp-stapler = "0.3" once_cell = "1.19" +parse-size = { version = "1.0", features = ["std"] } prometheus = "0.13" rand = "0.8" rasn = "0.15" diff --git a/src/cli.rs b/src/cli.rs index 227f1ed..8c3c53b 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -17,6 +17,10 @@ use crate::{ tls::{self, acme}, }; +fn parse_size(s: &str) -> Result { + parse_size::Config::new().with_binary().parse_size(s) +} + #[derive(Parser)] #[clap(name = SERVICE_NAME)] #[clap(author = AUTHOR_NAME)] @@ -132,6 +136,18 @@ pub struct HttpServer { /// How long to wait for the existing connections to finish before shutting down #[clap(long = "http-server-grace-period", default_value = "10s", value_parser = parse_duration)] pub grace_period: Duration, + + /// Maximum size of cache to store TLS sessions in memory + #[clap(long = "http-server-tls-session-cache-size", default_value = "256MB", value_parser = parse_size)] + pub tls_session_cache_size: u64, + + /// Maximum time that a TLS session key can stay in cache without being requested (Time-to-Idle) + #[clap(long = "http-server-tls-session-cache-tti", default_value = "18h", value_parser = parse_duration)] + pub tls_session_cache_tti: Duration, + + /// Lifetime of a TLS1.3 ticket, due to key rotation the actual lifetime will be twice than this. + #[clap(long = "http-server-tls-ticket-lifetime", default_value = "9h", value_parser = parse_duration)] + pub tls_ticket_lifetime: Duration, } #[derive(Args)] @@ -243,7 +259,7 @@ pub struct Acme { /// Attempt to renew the certificates when less than this duration is left until expiration. /// This works only with DNS challenge, ALPN currently starts to renew after half of certificate - /// lifetime has passed. + /// lifetime has passed (45d for LetsEncrypt) #[clap(long = "acme-renew-before", value_parser = parse_duration, default_value = "30d")] pub acme_renew_before: Duration, diff --git a/src/core.rs b/src/core.rs index 45da465..102aa3e 100644 --- a/src/core.rs +++ b/src/core.rs @@ -47,6 +47,10 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { let dns_resolver = http::dns::Resolver::new((&cli.dns).into()); let reqwest_client = http::client::new((&cli.http_client).into(), dns_resolver.clone())?; let http_client = Arc::new(http::ReqwestClient::new(reqwest_client.clone())); + let tls_session_cache = Arc::new(tls::sessions::Storage::new( + cli.http_server.tls_session_cache_size, + cli.http_server.tls_session_cache_tti, + )); let clickhouse = if cli.log.clickhouse.log_clickhouse_url.is_some() { Some(Arc::new( log::clickhouse::Clickhouse::new(&cli.log.clickhouse) @@ -106,6 +110,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { http_client.clone(), storage, Arc::new(dns_resolver), + tls_session_cache.clone(), ®istry, ) .await @@ -122,7 +127,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { // Setup metrics if let Some(addr) = cli.metrics.listen { - let router = metrics::setup(®istry, &mut tasks); + let router = metrics::setup(®istry, tls_session_cache, &mut tasks); let srv = Arc::new(http::Server::new( addr, diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index a16ef13..d6e34ce 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -38,6 +38,7 @@ use crate::{ RequestCtx, }, tasks::{Run, TaskManager}, + tls::sessions, }; use body::CountingBody; @@ -67,14 +68,23 @@ impl MetricsCache { pub struct MetricsRunner { metrics_cache: Arc>, registry: Registry, + tls_session_cache: Arc, encoder: TextEncoder, + + // Metrics mem_allocated: IntGauge, mem_resident: IntGauge, + tls_session_cache_count: IntGauge, + tls_session_cache_size: IntGauge, } // Snapshots & encodes the metrics for the handler to export impl MetricsRunner { - pub fn new(metrics_cache: Arc>, registry: &Registry) -> Self { + pub fn new( + metrics_cache: Arc>, + registry: &Registry, + tls_session_cache: Arc, + ) -> Self { let mem_allocated = register_int_gauge_with_registry!( format!("memory_allocated"), format!("Allocated memory in bytes"), @@ -89,12 +99,29 @@ impl MetricsRunner { ) .unwrap(); + let tls_session_cache_count = register_int_gauge_with_registry!( + format!("tls_session_cache_count"), + format!("Number of TLS sessions in the cache"), + registry + ) + .unwrap(); + + let tls_session_cache_size = register_int_gauge_with_registry!( + format!("tls_session_cache_size"), + format!("Size of TLS sessions in the cache"), + registry + ) + .unwrap(); + Self { metrics_cache, registry: registry.clone(), + tls_session_cache, encoder: TextEncoder::new(), mem_allocated, mem_resident, + tls_session_cache_count, + tls_session_cache_size, } } } @@ -108,6 +135,11 @@ impl MetricsRunner { self.mem_resident .set(stats::resident::read().unwrap() as i64); + // Record TLS session stats + let stats = self.tls_session_cache.stats(); + self.tls_session_cache_count.set(stats.entries as i64); + self.tls_session_cache_size.set(stats.size as i64); + // Get a snapshot of metrics let metric_families = self.registry.gather(); @@ -158,9 +190,17 @@ async fn handler(State(state): State>>) -> impl IntoRes ) } -pub fn setup(registry: &Registry, tasks: &mut TaskManager) -> Router { +pub fn setup( + registry: &Registry, + tls_session_cache: Arc, + tasks: &mut TaskManager, +) -> Router { let cache = Arc::new(RwLock::new(MetricsCache::new(METRICS_CACHE_CAPACITY))); - let runner = Arc::new(MetricsRunner::new(cache.clone(), registry)); + let runner = Arc::new(MetricsRunner::new( + cache.clone(), + registry, + tls_session_cache, + )); tasks.add("metrics_runner", runner); Router::new() diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 5025608..c350de6 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -145,7 +145,7 @@ pub fn setup_router( .route("/status", get(proxy::api_proxy)) .with_state(state_api); - let router = Router::new() + let mut router = Router::new() .nest("/api/v2", router_api) .fallback( get(handler::handler) @@ -155,15 +155,14 @@ pub fn setup_router( .layer(common_layers); // Setup issuer proxy endpoint if we have them configured - let router = if !cli.cert.issuer_urls.is_empty() { + if !cli.cert.issuer_urls.is_empty() { // Init it early to avoid threading races lazy_static::initialize(&proxy::REGEX_REG_ID); - // Strip possible path from URLs - let mut urls = cli.cert.issuer_urls.clone(); - urls.iter_mut().for_each(|x| x.set_path("")); - - let state = Arc::new(proxy::IssuerProxyState::new(http_client, urls)); + let state = Arc::new(proxy::IssuerProxyState::new( + http_client, + cli.cert.issuer_urls.clone(), + )); let router_issuer = Router::new() .route( "/:id", @@ -174,10 +173,8 @@ pub fn setup_router( .route("/", post(proxy::issuer_proxy)) .with_state(state); - router.nest("/registrations", router_issuer) - } else { - router - }; + router = router.nest("/registrations", router_issuer) + } Ok(router) } diff --git a/src/routing/proxy.rs b/src/routing/proxy.rs index bc468f9..368e6a1 100644 --- a/src/routing/proxy.rs +++ b/src/routing/proxy.rs @@ -6,7 +6,7 @@ use std::sync::{ use anyhow::Error; use axum::{ body::Body, - extract::{Path, Request, State}, + extract::{OriginalUri, Path, Request, State}, response::{IntoResponse, Response}, }; use candid::Principal; @@ -59,6 +59,7 @@ pub struct ApiProxyState { // Proxies /api/v2/... endpoints to the IC pub async fn api_proxy( State(state): State>, + OriginalUri(uri): OriginalUri, principal: Option>, request: Request, ) -> Result { @@ -75,7 +76,7 @@ pub async fn api_proxy( // Append the query URL to the IC url let url = url - .join(request.uri().path()) + .join(uri.path()) .map_err(|e| ErrorCause::MalformedRequest(format!("incorrect URL: {e}")))?; // Proxy the request @@ -97,6 +98,7 @@ pub struct IssuerProxyState { // Proxies /registrations endpoint to the certificate issuers if they're defined pub async fn issuer_proxy( State(state): State>, + OriginalUri(uri): OriginalUri, id: Option>, request: Request, ) -> Result { @@ -109,15 +111,12 @@ pub async fn issuer_proxy( } } - // Extract path part from the request - let path = request.uri().path(); - // Pick next issuer using round-robin & generate request URL for it // TODO should we do retries here? let next = state.next.fetch_add(1, Ordering::SeqCst) % state.issuers.len(); let url = state.issuers[next] .clone() - .join(path) + .join(uri.path()) .map_err(|_| ErrorCause::MalformedRequest("unable to parse path as URL part".into()))?; let response = proxy(url, request, &state.http_client) diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 3e67dc2..f503e5a 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,8 +1,10 @@ pub mod acme; pub mod cert; pub mod resolver; +pub mod sessions; +pub mod tickets; -use std::{fs, sync::Arc}; +use std::{fs, sync::Arc, time::Duration}; use anyhow::{anyhow, Context, Error}; use fqdn::{Fqdn, FQDN}; @@ -11,11 +13,9 @@ use ocsp_stapler::Stapler; use prometheus::Registry; use rustls::{ client::{ClientConfig, ClientSessionMemoryCache, Resumption}, - server::{ - ResolvesServerCert as ResolvesServerCertRustls, ServerConfig, ServerSessionMemoryCache, - }, + server::{ResolvesServerCert as ResolvesServerCertRustls, ServerConfig, StoresServerSessions}, version::{TLS12, TLS13}, - RootCertStore, + RootCertStore, TicketSwitcher, }; use rustls_acme::acme::ACME_TLS_ALPN_NAME; @@ -49,14 +49,25 @@ pub fn sni_matches(host: &Fqdn, domains: &[FQDN], wildcard: bool) -> bool { pub fn prepare_server_config( resolver: Arc, + session_storage: Arc, additional_alpn: Vec>, + ticket_lifetime: Duration, ) -> ServerConfig { let mut cfg = ServerConfig::builder_with_protocol_versions(&[&TLS13, &TLS12]) .with_no_client_auth() .with_cert_resolver(resolver); - // Create custom session storage with higher limit to allow effective TLS session resumption - cfg.session_storage = ServerSessionMemoryCache::new(131_072); + // Set custom session storage with to allow effective TLS session resumption + cfg.session_storage = session_storage; + + // Enable ticketer + let ticketer = TicketSwitcher::new(ticket_lifetime.as_secs() as u32, || { + Ok(Box::new(tickets::Ticketer::new())) + }) + .unwrap(); + cfg.ticketer = Arc::new(ticketer); + + // Enable tickets cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()]; cfg.alpn_protocols.extend_from_slice(&additional_alpn); @@ -134,6 +145,7 @@ async fn setup_acme( } // Prepares the stuff needed for serving TLS +#[allow(clippy::too_many_arguments)] pub async fn setup( cli: &Cli, tasks: &mut TaskManager, @@ -141,6 +153,7 @@ pub async fn setup( http_client: Arc, storage: Arc, dns_resolver: Arc, + tls_session_storage: Arc, registry: &Registry, ) -> Result { let mut providers: Vec> = vec![]; @@ -195,11 +208,13 @@ pub async fn setup( // Generate Rustls config let config = prepare_server_config( certificate_resolver, + tls_session_storage, if cli.acme.acme_challenge == Some(Challenge::Alpn) { vec![ACME_TLS_ALPN_NAME.to_vec()] } else { vec![] }, + cli.http_server.tls_ticket_lifetime, ); Ok(config) diff --git a/src/tls/resolver.rs b/src/tls/resolver.rs index c544760..67f6a3e 100644 --- a/src/tls/resolver.rs +++ b/src/tls/resolver.rs @@ -23,6 +23,8 @@ pub struct AggregatingResolver { // Implement certificate resolving for Rustls impl ResolvesServerCertRustls for AggregatingResolver { fn resolve(&self, ch: ClientHello) -> Option> { + println!("{:?}", ch.server_name()); + // Iterate over our resolvers to find matching cert if any. self.resolvers .iter() diff --git a/src/tls/sessions.rs b/src/tls/sessions.rs new file mode 100644 index 0000000..d4e96f9 --- /dev/null +++ b/src/tls/sessions.rs @@ -0,0 +1,106 @@ +use std::time::Duration; + +use ahash::RandomState; +use moka::sync::Cache; +use rustls::server::StoresServerSessions; + +type Key = Vec; +type Val = Vec; + +fn weigher(k: &Key, v: &Val) -> u32 { + (k.len() + v.len()) as u32 +} + +pub struct Stats { + pub entries: u64, + pub size: u64, +} + +/// Stores TLS sessions for TLSv1.2 only. +/// SipHash is replaced with ~10x faster aHash. +/// see https://github.com/tkaitchuck/aHash/blob/master/compare/readme.md +#[derive(Debug)] +pub struct Storage { + cache: Cache, +} + +impl Storage { + pub fn new(capacity: u64, tti: Duration) -> Self { + let cache = Cache::builder() + .max_capacity(capacity) + .time_to_idle(tti) + .weigher(weigher) + .build_with_hasher(RandomState::default()); + + Self { cache } + } + + pub fn stats(&self) -> Stats { + self.cache.run_pending_tasks(); + Stats { + entries: self.cache.entry_count(), + size: self.cache.weighted_size(), + } + } +} + +impl StoresServerSessions for Storage { + fn get(&self, key: &[u8]) -> Option> { + self.cache.get(key) + } + + fn put(&self, key: Vec, value: Vec) -> bool { + self.cache.insert(key, value); + true + } + + fn take(&self, key: &[u8]) -> Option> { + self.cache.remove(key) + } + + fn can_cache(&self) -> bool { + true + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_storage() { + let c = Storage::new(10000, Duration::from_secs(3600)); + + let key1 = "a".repeat(2500).as_bytes().to_vec(); + let key2 = "b".repeat(2500).as_bytes().to_vec(); + let key3 = "b".as_bytes().to_vec(); + + // Check that two entries fit + c.put(key1.clone(), key1.clone()); + c.cache.run_pending_tasks(); + assert_eq!(c.cache.entry_count(), 1); + assert_eq!(c.cache.weighted_size(), 5000); + c.put(key2.clone(), key2.clone()); + c.cache.run_pending_tasks(); + assert_eq!(c.cache.entry_count(), 2); + assert_eq!(c.cache.weighted_size(), 10000); + + // Check that 3rd entry won't fit + c.put(key3.clone(), key3.clone()); + c.cache.run_pending_tasks(); + assert_eq!(c.cache.entry_count(), 2); + assert_eq!(c.cache.weighted_size(), 10000); + assert!(c.get(&key3).is_none()); + + // Check that keys are taken and not left + assert!(c.take(&key1).is_some()); + assert!(c.get(&key1).is_none()); + assert!(c.take(&key2).is_some()); + assert!(c.get(&key2).is_none()); + + // Check that nothing left + c.cache.run_pending_tasks(); + assert_eq!(c.cache.entry_count(), 0); + assert_eq!(c.cache.weighted_size(), 0); + } +} diff --git a/src/tls/tickets.rs b/src/tls/tickets.rs new file mode 100644 index 0000000..4a65e71 --- /dev/null +++ b/src/tls/tickets.rs @@ -0,0 +1,110 @@ +use std::{ + fmt, + sync::atomic::{AtomicU32, Ordering}, +}; + +use chacha20poly1305::{ + aead::{Aead, AeadCore, KeyInit, OsRng}, + XChaCha20Poly1305, XNonce, +}; +use rustls::server::ProducesTickets; + +// We're using 192-bit nonce +const NONCE_LEN: usize = 192 / 8; + +/// Encrypts & decrypts tickets for TLS 1.3 session resumption. +/// Must be used with rustls::ticketer::TicketSwitcher to facilitate key rotation. +/// We're using XChaCha20Poly1305 authenicated encryption (AEAD) +pub struct Ticketer { + counter: AtomicU32, + cipher: XChaCha20Poly1305, +} + +impl fmt::Debug for Ticketer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Ticketer") + } +} + +impl Ticketer { + pub fn new() -> Self { + // Generate random key that is valid for the lifetime of this ticketer + let key = XChaCha20Poly1305::generate_key(&mut OsRng); + let cipher = XChaCha20Poly1305::new(&key); + Self { + cipher, + counter: AtomicU32::new(0), + } + } + + /// Generates a random nonce and then replaces first 4 bytes of it with a counter. + /// Purely random nonces seem to be less secure, though 192-bit XNonce that we're using might be Ok. + /// See https://docs.rs/aead/latest/aead/trait.AeadCore.html#security-warning + fn nonce(&self) -> XNonce { + let mut nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng); + let count = self.counter.fetch_add(1, Ordering::SeqCst); + nonce[0..4].copy_from_slice(&count.to_le_bytes()); + nonce + } +} + +impl ProducesTickets for Ticketer { + fn enabled(&self) -> bool { + true + } + + fn decrypt(&self, cipher: &[u8]) -> Option> { + // Check if the ciphertext is too short + if cipher.len() <= NONCE_LEN { + return None; + } + + // Extract nonce + let nonce = XNonce::from_slice(&cipher[0..NONCE_LEN]); + + // Try to decrypt + let plaintext = self.cipher.decrypt(nonce, &cipher[NONCE_LEN..]).ok()?; + Some(plaintext) + } + + fn encrypt(&self, plain: &[u8]) -> Option> { + // Generate nonce & encrypt + let nonce = self.nonce(); + let ciphertext = self.cipher.encrypt(&nonce, plain).ok()?; + + // Concatenate nonce & ciphertext + let mut result = Vec::with_capacity(nonce.len() + ciphertext.len()); + result.extend_from_slice(nonce.as_slice()); + result.extend_from_slice(&ciphertext); + + Some(result) + } + + fn lifetime(&self) -> u32 { + // Lifetime here isn't important since it's designed to be used under TicketSwitcher + // which manages its own lifetimes + 3600 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ticketer() { + let t = Ticketer::new(); + + // Make sure that nonce is using a counter + for i in 0..10 { + let counter = u32::from_le_bytes(t.nonce().as_slice()[0..4].try_into().unwrap()); + assert_eq!(counter, i); + } + + // Check encryption & decryption + let msg = b"The quick brown fox jumps over the lazy dog"; + let ciphertext = t.encrypt(msg).unwrap(); + let plaintext = t.decrypt(&ciphertext).unwrap(); + assert_eq!(&msg[..], plaintext); + } +}