Skip to content

Commit

Permalink
comments and renaming in lib::tls
Browse files Browse the repository at this point in the history
  • Loading branch information
Keksoj committed Nov 14, 2023
1 parent 50afe7a commit cc12789
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 46 deletions.
4 changes: 2 additions & 2 deletions lib/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ use crate::{
server::{ListenSession, ListenToken, ProxyChannel, Server, SessionManager, SessionToken},
socket::{server_bind, FrontRustls},
timer::TimeoutContainer,
tls::{MutexWrappedCertificateResolver, ResolveCertificate, StoredCertificate},
tls::{CertifiedKeyWrapper, MutexWrappedCertificateResolver, ResolveCertificate},
util::UnwrapLog,
AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError,
ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed,
Expand Down Expand Up @@ -600,7 +600,7 @@ impl L7ListenerHandler for HttpsListener {
impl ResolveCertificate for HttpsListener {
type Error = ListenerError;

fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<StoredCertificate> {
fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
let resolver = self
.resolver
.0
Expand Down
1 change: 1 addition & 0 deletions lib/src/router/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ fn find_last_dot(input: &[u8]) -> Option<usize> {
(0..input.len()).rev().find(|&i| input[i] == b'.')
}

/// A custom implementation of the [Trie data structure](https://www.wikiwand.com/en/Trie)
#[derive(Debug, PartialEq)]
pub struct TrieNode<V> {
key_value: Option<KeyValue<Key, V>>,
Expand Down
87 changes: 43 additions & 44 deletions lib/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::{
collections::{HashMap, HashSet},
convert::From,
io::BufReader,
str::FromStr,
sync::{Arc, Mutex},
str::FromStr
};

use once_cell::sync::Lazy;
Expand Down Expand Up @@ -42,7 +42,7 @@ static DEFAULT_CERTIFICATE: Lazy<Option<Arc<CertifiedKey>>> = Lazy::new(|| {

CertificateResolver::parse(&certificate_and_key)
.ok()
.map(|c| c.certified_key)
.map(|c| c.inner)
});

// -----------------------------------------------------------------------------
Expand All @@ -52,7 +52,7 @@ pub trait ResolveCertificate {
type Error;

/// return the certificate in both a Rustls-usable form, and the pem format
fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<StoredCertificate>;
fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper>;

/// persist a certificate, after ensuring validity, and checking if it can replace another certificate
fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error>;
Expand Down Expand Up @@ -86,6 +86,7 @@ pub trait ResolveCertificate {
// -----------------------------------------------------------------------------
// CertificateOverride struct

/// Enables use of certificates for more domain names
#[derive(Clone, Debug)]
pub struct CertificateOverride {
pub names: Option<HashSet<String>>,
Expand All @@ -110,14 +111,14 @@ impl From<&AddCertificate> for CertificateOverride {
/// [`CertifiedKey` type](https://docs.rs/rustls/latest/rustls/sign/struct.CertifiedKey.html),
/// stored and returned by the certificate resolver.
#[derive(Clone)]
pub struct StoredCertificate {
certified_key: Arc<CertifiedKey>,
pub struct CertifiedKeyWrapper {
inner: Arc<CertifiedKey>,
}

impl StoredCertificate {
impl CertifiedKeyWrapper {
/// bytes of the pem formatted certificate, first of the chain
fn pem_bytes(&self) -> &[u8] {
&self.certified_key.cert[0].0
&self.inner.cert[0].0
}
}

Expand Down Expand Up @@ -148,25 +149,28 @@ impl From<CertificateError> for CertificateResolverError {
/// Parses and stores TLS certificates, makes them available to Rustls for TLS handshakes
#[derive(Default)]
pub struct CertificateResolver {
/// all fingerprints of all
pub domains: TrieNode<Fingerprint>,
/// a map of fingerprint -> stored_certificate
certificates: HashMap<Fingerprint, StoredCertificate>,
certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
/// map of domain_name -> all fingerprints linked to this domain name
name_fingerprint_idx: HashMap<String, HashSet<Fingerprint>>,
/// map of fingerprint -> domain names to override
overrides: HashMap<Fingerprint, CertificateOverride>,
}

impl ResolveCertificate for CertificateResolver {
type Error = CertificateResolverError;

fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<StoredCertificate> {
fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
self.certificates.get(fingerprint).map(ToOwned::to_owned)
}

fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error> {
// Check if we could parse the certificate, chain and private key, if not just throw an
// error.
let stored_certificate = Self::parse(&opts.certificate)?;
let fingerprint = fingerprint(stored_certificate.pem_bytes());
let certificate_to_add = Self::parse(&opts.certificate)?;
let fingerprint = fingerprint(certificate_to_add.pem_bytes());
if !opts.certificate.names.is_empty() || opts.expired_at.is_some() {
self.overrides
.insert(fingerprint.to_owned(), CertificateOverride::from(opts));
Expand All @@ -175,25 +179,26 @@ impl ResolveCertificate for CertificateResolver {
}

let (should_insert, certificates_to_remove) =
self.should_insert(&fingerprint, &stored_certificate)?;
self.should_insert(&fingerprint, &certificate_to_add)?;
if !should_insert {
// if we do not need to insert the fingerprint just return the fingerprint
return Ok(fingerprint);
}

let new_names = match self.get_names_override(&fingerprint) {
Some(names) => names,
None => self.certificate_names(stored_certificate.pem_bytes())?,
None => self.certificate_names(certificate_to_add.pem_bytes())?,
};

self.certificates
.insert(fingerprint.to_owned(), stored_certificate);
for name in new_names {
.insert(fingerprint.to_owned(), certificate_to_add);

for new_name in new_names {
self.domains
.insert(name.to_owned().into_bytes(), fingerprint.to_owned());
.insert(new_name.to_owned().into_bytes(), fingerprint.to_owned());

self.name_fingerprint_idx
.entry(name)
.entry(new_name)
.or_insert_with(HashSet::new)
.insert(fingerprint.to_owned());
}
Expand All @@ -212,10 +217,10 @@ impl ResolveCertificate for CertificateResolver {
}

fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> {
if let Some(certified_key_and_pem) = self.get_certificate(fingerprint) {
if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
let names = match self.get_names_override(fingerprint) {
Some(names) => names,
None => self.certificate_names(certified_key_and_pem.pem_bytes())?,
None => self.certificate_names(certificate_to_remove.pem_bytes())?,
};

for name in &names {
Expand Down Expand Up @@ -260,23 +265,26 @@ impl CertificateResolver {

/// return the hashset of subjects that the certificate is able to handle, by
/// parsing the pem file and scrapping the information
fn certificate_names(&self, pem: &[u8]) -> Result<HashSet<String>, CertificateResolverError> {
let fingerprint = fingerprint(pem);
fn certificate_names(
&self,
pem_bytes: &[u8],
) -> Result<HashSet<String>, CertificateResolverError> {
let fingerprint = fingerprint(pem_bytes);
if let Some(certificate_override) = self.overrides.get(&fingerprint) {
if let Some(names) = &certificate_override.names {
return Ok(names.to_owned());
}
}

get_cn_and_san_attributes(pem)
get_cn_and_san_attributes(pem_bytes)
.map_err(CertificateResolverError::InvalidCommonNameAndSubjectAlternateNames)
}

/// Parse a raw certificate into the Rustls format.
/// Parses RSA and ECDSA certificates.
fn parse(
certificate_and_key: &CertificateAndKey,
) -> Result<StoredCertificate, CertificateResolverError> {
) -> Result<CertifiedKeyWrapper, CertificateResolverError> {
let certificate_pem =
sozu_command::certificate::parse_pem(certificate_and_key.certificate.as_bytes())?;

Expand Down Expand Up @@ -304,8 +312,8 @@ impl CertificateResolver {
};
match rustls::sign::any_supported_type(&private_key) {
Ok(signing_key) => {
let stored_certificate = StoredCertificate {
certified_key: Arc::new(CertifiedKey::new(chain, signing_key)),
let stored_certificate = CertifiedKeyWrapper {
inner: Arc::new(CertifiedKey::new(chain, signing_key)),
};
Ok(stored_certificate)
}
Expand All @@ -320,14 +328,14 @@ impl CertificateResolver {
fn should_insert(
&self,
fingerprint: &Fingerprint,
stored_certificate: &StoredCertificate,
candidate_cert: &CertifiedKeyWrapper,
) -> Result<(bool, HashMap<Fingerprint, HashSet<String>>), CertificateResolverError> {
let x509 = parse_x509(stored_certificate.pem_bytes())?;
let x509 = parse_x509(candidate_cert.pem_bytes())?;

// We need to know if the new certificate can replace an already existing one.
let new_names = match self.get_names_override(fingerprint) {
Some(names) => names,
None => self.certificate_names(stored_certificate.pem_bytes())?,
None => self.certificate_names(candidate_cert.pem_bytes())?,
};

let expiration = self
Expand All @@ -337,8 +345,8 @@ impl CertificateResolver {
let fingerprints = self.find_certificates_by_names(&new_names)?;
let mut certificates = HashMap::new();
for fingerprint in &fingerprints {
if let Some(certified_key_and_pem) = self.get_certificate(fingerprint) {
certificates.insert(fingerprint, certified_key_and_pem);
if let Some(cert) = self.get_certificate(fingerprint) {
certificates.insert(fingerprint, cert);
}
}

Expand Down Expand Up @@ -442,7 +450,7 @@ impl ResolvesServerCert for MutexWrappedCertificateResolver {
let cert = resolver
.certificates
.get(fingerprint)
.map(|cert| cert.certified_key.clone());
.map(|cert| cert.inner.clone());

trace!("Found for fingerprint {}: {}", fingerprint, cert.is_some());
return cert;
Expand Down Expand Up @@ -507,10 +515,7 @@ mod tests {
if !resolver.find_certificates_by_names(&names)?.is_empty()
&& resolver.get_certificate(&fingerprint).is_some()
{
return Err(
"We have retrieve the certificate that should be deleted"
.into(),
);
return Err("We have retrieve the certificate that should be deleted".into());
}

Ok(())
Expand All @@ -530,7 +535,7 @@ mod tests {
let pem = parse_pem(certificate_and_key.certificate.as_bytes())?;

let fingerprint = resolver.add_certificate(&AddCertificate {
address: address,
address,
certificate: certificate_and_key,
expired_at: None,
})?;
Expand All @@ -544,10 +549,7 @@ mod tests {
if resolver.find_certificates_by_names(&lolcat)?.is_empty()
|| resolver.get_certificate(&fingerprint).is_none()
{
return Err(
"failed to retrieve certificate with custom names"
.into(),
);
return Err("failed to retrieve certificate with custom names".into());
}

if let Err(err) = resolver.remove_certificate(&fingerprint) {
Expand All @@ -558,10 +560,7 @@ mod tests {
if !resolver.find_certificates_by_names(&names)?.is_empty()
&& resolver.get_certificate(&fingerprint).is_some()
{
return Err(
"We have retrieve the certificate that should be deleted"
.into(),
);
return Err("We have retrieve the certificate that should be deleted".into());
}

Ok(())
Expand Down

0 comments on commit cc12789

Please sign in to comment.