Skip to content

Commit

Permalink
make cert-issuer-provider poll itself internally
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Sep 20, 2024
1 parent c918364 commit b4df1ad
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 46 deletions.
6 changes: 5 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,12 @@ pub struct Cert {
#[clap(env, long, value_delimiter = ',')]
pub cert_provider_issuer_url: Vec<Url>,

/// How frequently to poll providers for certificates
/// How frequently to refresh certificate issuers
#[clap(env, long, default_value = "30s", value_parser = parse_duration)]
pub cert_provider_issuer_poll_interval: Duration,

/// How frequently to poll providers for certificates
#[clap(env, long, default_value = "5s", value_parser = parse_duration)]
pub cert_provider_poll_interval: Duration,

/// Disable OCSP stapling
Expand Down
119 changes: 76 additions & 43 deletions src/tls/cert/providers/issuer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ use std::{
};

use anyhow::{anyhow, Context as AnyhowContext};
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use candid::Principal;
use fqdn::FQDN;
use ic_bn_lib::http;
use ic_bn_lib::{http, tasks::Run};
use mockall::automock;
use reqwest::{Method, Request, StatusCode, Url};
use serde::Deserialize;
use tokio::sync::Mutex;
use tracing::info;
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};

use crate::routing::domain::{CustomDomain, ProvidesCustomDomains};
use verify::{Verify, VerifyError, WithVerify};

use super::{Pem, ProvidesCertificates};

const CACHE_TTL: Duration = Duration::from_secs(9);

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Expand Down Expand Up @@ -52,15 +52,11 @@ pub trait Import: Sync + Send {
async fn import(&self) -> Result<Vec<Package>, Error>;
}

struct Cache {
updated_at: Instant,
packages: Vec<Package>,
}

pub struct CertificatesImporter {
http_client: Arc<dyn http::Client>,
exporter_url: Url,
cache: Mutex<Cache>,
poll_interval: Duration,
snapshot: ArcSwapOption<Vec<Package>>,
}

impl std::fmt::Debug for CertificatesImporter {
Expand All @@ -70,28 +66,46 @@ impl std::fmt::Debug for CertificatesImporter {
}

impl CertificatesImporter {
pub fn new(http_client: Arc<dyn http::Client>, mut exporter_url: Url) -> Self {
pub fn new(
http_client: Arc<dyn http::Client>,
mut exporter_url: Url,
poll_interval: Duration,
) -> Self {
exporter_url.set_path("");
let exporter_url = exporter_url.join("/certificates").unwrap();

Self {
http_client,
exporter_url,
cache: Mutex::new(Cache {
updated_at: Instant::now().checked_sub(CACHE_TTL * 2).unwrap(),
packages: vec![],
}),
poll_interval,
snapshot: ArcSwapOption::empty(),
}
}

async fn refresh(&self) -> Result<(), Error> {
let start = Instant::now();
let packages = self.import().await.context("unable to fetch packages")?;
info!(
"{self:?}: {} certs loaded in {}s",
packages.len(),
start.elapsed().as_secs_f64()
);

self.snapshot.store(Some(Arc::new(packages)));
Ok(())
}
}

#[async_trait]
impl ProvidesCustomDomains for CertificatesImporter {
async fn get_custom_domains(&self) -> Result<Vec<CustomDomain>, anyhow::Error> {
let domains = self
.import()
.await?
.into_iter()
let packages = self
.snapshot
.load_full()
.ok_or_else(|| anyhow!("no packages fetched yet"))?;

let domains = packages
.iter()
.map(|x| -> Result<_, anyhow::Error> {
Ok(CustomDomain {
name: FQDN::from_str(&x.name)?,
Expand All @@ -107,63 +121,79 @@ impl ProvidesCustomDomains for CertificatesImporter {
#[async_trait]
impl ProvidesCertificates for CertificatesImporter {
async fn get_certificates(&self) -> Result<Vec<Pem>, anyhow::Error> {
let certs = self
.import()
.await?
let packages = self
.snapshot
.load_full()
.ok_or_else(|| anyhow!("no packages fetched yet"))?;

let certs = packages
.as_ref()
.clone()
.into_iter()
.map(|x| Pem {
cert: x.pair.1,
key: x.pair.0,
})
.collect::<Vec<_>>();

info!(
"IssuerProvider ({}): {} certs loaded",
self.exporter_url,
certs.len()
);

Ok(certs)
}
}

#[allow(clippy::significant_drop_tightening)]
#[async_trait]
impl Import for CertificatesImporter {
async fn import(&self) -> Result<Vec<Package>, Error> {
// Return result from cache if available
let now = Instant::now();
let mut cache = self.cache.lock().await;
if cache.updated_at >= now.checked_sub(CACHE_TTL).unwrap() {
return Ok(cache.packages.clone());
}
let mut req = Request::new(Method::GET, self.exporter_url.clone());
*req.timeout_mut() = Some(Duration::from_secs(30));

let req = Request::new(Method::GET, self.exporter_url.clone());
let response = self
.http_client
.execute(req)
.await
.context("failed to make http request")?;

if response.status() != StatusCode::OK {
return Err(anyhow!(format!("request failed: {}", response.status())).into());
return Err(anyhow!("incorrect response code: {}", response.status()).into());
}

let bs = response
.bytes()
.await
.context("failed to consume response")?
.context("failed to fetch response body")?
.to_vec();

let pkgs: Vec<Package> =
serde_json::from_slice(&bs).context("failed to parse json body")?;

cache.packages.clone_from(&pkgs);
cache.updated_at = now;
Ok(pkgs)
}
}

#[async_trait]
impl Run for CertificatesImporter {
async fn run(&self, token: CancellationToken) -> Result<(), anyhow::Error> {
let mut interval = tokio::time::interval(self.poll_interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

loop {
select! {
biased;

() = token.cancelled() => {
warn!("{self:?}: exiting");
return Ok(());
},

_ = interval.tick() => {
if let Err(e) = self.refresh().await {
warn!("{self:?}: unable to refresh certificates: {e:#}");
};
}
}
}
}
}

// Wraps an importer with a verifier
// The importer imports a set of packages as usual, but then passes the packages to the verifier.
// The verifier parses out the public certificate and compares the common name to the name in the package to make sure they match.
Expand Down Expand Up @@ -222,8 +252,11 @@ mod tests {
.into())
});

let importer =
CertificatesImporter::new(Arc::new(http_client), Url::from_str("http://foo")?);
let importer = CertificatesImporter::new(
Arc::new(http_client),
Url::from_str("http://foo")?,
Duration::ZERO,
);

let out = importer.import().await?;

Expand Down
10 changes: 8 additions & 2 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,15 @@ pub async fn setup(
// Create CertIssuer providers
// It's a custom domain & cert provider at the same time.
for v in &cli.cert.cert_provider_issuer_url {
let issuer = Arc::new(providers::Issuer::new(http_client.clone(), v.clone()));
let issuer = Arc::new(providers::Issuer::new(
http_client.clone(),
v.clone(),
cli.cert.cert_provider_issuer_poll_interval,
));

cert_providers.push(issuer.clone());
custom_domain_providers.push(issuer);
custom_domain_providers.push(issuer.clone());
tasks.add(&format!("{issuer:?}"), issuer);
}

// Prepare ACME if configured
Expand Down

0 comments on commit b4df1ad

Please sign in to comment.