diff --git a/Cargo.toml b/Cargo.toml index 4345229..34bfdfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,7 +104,7 @@ tokio = { version = "1.36", features = ["full"] } tokio-util = { version = "0.7", features = ["full"] } tokio-rustls = "0.26" tower = "0.4" -tower_governor = "0.3" +tower_governor = "0.4" tower-http = { version = "0.5", features = [ "trace", "request-id", diff --git a/src/policy/denylist.rs b/src/policy/denylist.rs index af8de35..1051e55 100644 --- a/src/policy/denylist.rs +++ b/src/policy/denylist.rs @@ -18,7 +18,7 @@ use tokio_util::sync::CancellationToken; use tracing::{info, warn}; use url::Url; -use super::load_canister_list; +use super::load_principal_list; use crate::{http::Client, routing::middleware::geoip::CountryCode, tasks::Run}; pub struct Denylist { @@ -59,7 +59,7 @@ impl Denylist { registry: &Registry, ) -> Result { let allowlist = if let Some(v) = allowlist { - let r = load_canister_list(&v).context("unable to read allowlist")?; + let r = load_principal_list(&v).context("unable to read allowlist")?; warn!("Denylist allowlist loaded: {}", r.len()); r } else { diff --git a/src/policy/mod.rs b/src/policy/mod.rs index c6868c1..0213d9d 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -7,9 +7,9 @@ use std::{collections::HashSet, fs, path::PathBuf}; use anyhow::{Context, Error}; use candid::Principal; -// Generic function to load a list of canister ids from a text file into a HashSet -pub fn load_canister_list(path: &PathBuf) -> Result, Error> { - let data = fs::read_to_string(path).context("failed to read canisters file")?; +// Generic function to load a list of principals from a text file into a HashSet +pub fn load_principal_list(path: &PathBuf) -> Result, Error> { + let data = fs::read_to_string(path).context("failed to read file")?; let set = data .lines() .filter(|x| !x.trim().is_empty()) diff --git a/src/routing/middleware/canister_match.rs b/src/routing/middleware/canister_match.rs index f79571f..bd18978 100644 --- a/src/routing/middleware/canister_match.rs +++ b/src/routing/middleware/canister_match.rs @@ -9,7 +9,7 @@ use axum::{ use crate::{ cli::Cli, - policy::{domain_canister::DomainCanisterMatcher, load_canister_list}, + policy::{domain_canister::DomainCanisterMatcher, load_principal_list}, routing::{ErrorCause, RequestCtx}, }; @@ -19,7 +19,7 @@ pub struct CanisterMatcherState(Arc); impl CanisterMatcherState { pub fn new(cli: &Cli) -> Result { let pre_isolation_canisters = if let Some(v) = cli.policy.pre_isolation_canisters.as_ref() { - load_canister_list(v).context("unable to load pre-isolation canisters")? + load_principal_list(v).context("unable to load pre-isolation canisters")? } else { HashSet::new() }; diff --git a/src/routing/middleware/rate_limiter.rs b/src/routing/middleware/rate_limiter.rs index 96bc807..9f7bbab 100644 --- a/src/routing/middleware/rate_limiter.rs +++ b/src/routing/middleware/rate_limiter.rs @@ -1,10 +1,11 @@ use std::{net::IpAddr, sync::Arc, time::Duration}; use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware}; +use anyhow::{anyhow, Error}; use axum::{extract::Request, response::IntoResponse}; use tower::{ layer::util::{Identity, Stack}, - ServiceBuilder, + Layer, Service, ServiceBuilder, }; use tower_governor::{ governor::GovernorConfigBuilder, key_extractor::KeyExtractor, GovernorError, GovernorLayer, @@ -31,15 +32,17 @@ impl KeyExtractor for IpKeyExtractor { } } -pub fn build_rate_limiter_middleware( +pub fn build_middleware( rps: u32, burst_size: u32, key_extractor: T, rate_limit_cause: RateLimitCause, -) -> Option>, Identity>>> -{ - let period = Duration::from_secs(1).checked_div(rps)?; - let governor_conf = Box::new( +) -> Result>, Error> { + let period = Duration::from_secs(1) + .checked_div(rps) + .ok_or_else(|| anyhow!("RPS is zero"))?; + + let config = Arc::new( GovernorConfigBuilder::default() .period(period) .error_handler(move |err| match err { @@ -51,20 +54,14 @@ pub fn build_rate_limiter_middleware( } GovernorError::Other { code, msg, headers } => { let msg = format!("Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}"); - debug!("{msg}"); ErrorCause::Other(msg).into_response() } }) .burst_size(burst_size) .key_extractor(key_extractor) - .finish()?, - ); - - let gov_layer = GovernorLayer { - config: Box::leak(governor_conf), - }; + .finish().ok_or_else(|| anyhow!("unable to build governor config"))?); - Some(ServiceBuilder::new().layer(gov_layer)) + Ok(GovernorLayer { config }) } #[cfg(test)] diff --git a/src/routing/mod.rs b/src/routing/mod.rs index c350de6..910ef89 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -28,11 +28,11 @@ use crate::{ http::Client, log::clickhouse::Clickhouse, metrics, - routing::middleware::{canister_match, geoip, headers, request_id, validate}, + routing::middleware::{canister_match, geoip, headers, rate_limiter, request_id, validate}, tasks::TaskManager, }; -use self::middleware::denylist; +use self::{error_cause::RateLimitCause, middleware::denylist}; use { canister::{Canister, ResolvesCanister}, @@ -164,6 +164,15 @@ pub fn setup_router( cli.cert.issuer_urls.clone(), )); let router_issuer = Router::new() + .layer( + rate_limiter::build_middleware( + 5, + 10, + rate_limiter::IpKeyExtractor, + RateLimitCause::Normal, + ) + .unwrap(), + ) .route( "/:id", get(proxy::issuer_proxy)