Skip to content

Commit

Permalink
refactor ratelimit a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Jun 3, 2024
1 parent 13a0501 commit a6c30a5
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ tokio = { version = "1.36", features = ["full"] }
tokio-util = { version = "0.7", features = ["full"] }
tokio-rustls = "0.26"
tower = "0.4"
tower_governor = "0.3"
tower_governor = "0.4"
tower-http = { version = "0.5", features = [
"trace",
"request-id",
Expand Down
4 changes: 2 additions & 2 deletions src/policy/denylist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use url::Url;

use super::load_canister_list;
use super::load_principal_list;
use crate::{http::Client, routing::middleware::geoip::CountryCode, tasks::Run};

pub struct Denylist {
Expand Down Expand Up @@ -59,7 +59,7 @@ impl Denylist {
registry: &Registry,
) -> Result<Self, Error> {
let allowlist = if let Some(v) = allowlist {
let r = load_canister_list(&v).context("unable to read allowlist")?;
let r = load_principal_list(&v).context("unable to read allowlist")?;
warn!("Denylist allowlist loaded: {}", r.len());
r
} else {
Expand Down
6 changes: 3 additions & 3 deletions src/policy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use std::{collections::HashSet, fs, path::PathBuf};
use anyhow::{Context, Error};
use candid::Principal;

// Generic function to load a list of canister ids from a text file into a HashSet
pub fn load_canister_list(path: &PathBuf) -> Result<HashSet<Principal>, Error> {
let data = fs::read_to_string(path).context("failed to read canisters file")?;
// Generic function to load a list of principals from a text file into a HashSet
pub fn load_principal_list(path: &PathBuf) -> Result<HashSet<Principal>, Error> {
let data = fs::read_to_string(path).context("failed to read file")?;
let set = data
.lines()
.filter(|x| !x.trim().is_empty())
Expand Down
4 changes: 2 additions & 2 deletions src/routing/middleware/canister_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{

use crate::{
cli::Cli,
policy::{domain_canister::DomainCanisterMatcher, load_canister_list},
policy::{domain_canister::DomainCanisterMatcher, load_principal_list},
routing::{ErrorCause, RequestCtx},
};

Expand All @@ -19,7 +19,7 @@ pub struct CanisterMatcherState(Arc<DomainCanisterMatcher>);
impl CanisterMatcherState {
pub fn new(cli: &Cli) -> Result<Self, Error> {
let pre_isolation_canisters = if let Some(v) = cli.policy.pre_isolation_canisters.as_ref() {
load_canister_list(v).context("unable to load pre-isolation canisters")?
load_principal_list(v).context("unable to load pre-isolation canisters")?
} else {
HashSet::new()
};
Expand Down
25 changes: 11 additions & 14 deletions src/routing/middleware/rate_limiter.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{net::IpAddr, sync::Arc, time::Duration};

use ::governor::{clock::QuantaInstant, middleware::NoOpMiddleware};
use anyhow::{anyhow, Error};
use axum::{extract::Request, response::IntoResponse};
use tower::{
layer::util::{Identity, Stack},
ServiceBuilder,
Layer, Service, ServiceBuilder,
};
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::KeyExtractor, GovernorError, GovernorLayer,
Expand All @@ -31,15 +32,17 @@ impl KeyExtractor for IpKeyExtractor {
}
}

pub fn build_rate_limiter_middleware<T: KeyExtractor>(
pub fn build_middleware<T: KeyExtractor + Send + Sync + 'static>(
rps: u32,
burst_size: u32,
key_extractor: T,
rate_limit_cause: RateLimitCause,
) -> Option<ServiceBuilder<Stack<GovernorLayer<'static, T, NoOpMiddleware<QuantaInstant>>, Identity>>>
{
let period = Duration::from_secs(1).checked_div(rps)?;
let governor_conf = Box::new(
) -> Result<GovernorLayer<T, NoOpMiddleware<QuantaInstant>>, Error> {
let period = Duration::from_secs(1)
.checked_div(rps)
.ok_or_else(|| anyhow!("RPS is zero"))?;

let config = Arc::new(
GovernorConfigBuilder::default()
.period(period)
.error_handler(move |err| match err {
Expand All @@ -51,20 +54,14 @@ pub fn build_rate_limiter_middleware<T: KeyExtractor>(
}
GovernorError::Other { code, msg, headers } => {
let msg = format!("Rate limiter failed unexpectedly: code={code}, msg={msg:?}, headers={headers:?}");
debug!("{msg}");
ErrorCause::Other(msg).into_response()
}
})
.burst_size(burst_size)
.key_extractor(key_extractor)
.finish()?,
);

let gov_layer = GovernorLayer {
config: Box::leak(governor_conf),
};
.finish().ok_or_else(|| anyhow!("unable to build governor config"))?);

Some(ServiceBuilder::new().layer(gov_layer))
Ok(GovernorLayer { config })
}

#[cfg(test)]
Expand Down
13 changes: 11 additions & 2 deletions src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ use crate::{
http::Client,
log::clickhouse::Clickhouse,
metrics,
routing::middleware::{canister_match, geoip, headers, request_id, validate},
routing::middleware::{canister_match, geoip, headers, rate_limiter, request_id, validate},
tasks::TaskManager,
};

use self::middleware::denylist;
use self::{error_cause::RateLimitCause, middleware::denylist};

use {
canister::{Canister, ResolvesCanister},
Expand Down Expand Up @@ -164,6 +164,15 @@ pub fn setup_router(
cli.cert.issuer_urls.clone(),
));
let router_issuer = Router::new()
.layer(
rate_limiter::build_middleware(
5,
10,
rate_limiter::IpKeyExtractor,
RateLimitCause::Normal,
)
.unwrap(),
)
.route(
"/:id",
get(proxy::issuer_proxy)
Expand Down

0 comments on commit a6c30a5

Please sign in to comment.