diff --git a/Cargo.lock b/Cargo.lock index a9595b045a..54a65bccdf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -895,7 +895,6 @@ version = "0.1.0" dependencies = [ "thiserror", "untrusted", - "webpki", ] [[package]] diff --git a/linkerd/app/gateway/src/gateway.rs b/linkerd/app/gateway/src/gateway.rs index e7947216c7..3678b48004 100644 --- a/linkerd/app/gateway/src/gateway.rs +++ b/linkerd/app/gateway/src/gateway.rs @@ -159,7 +159,7 @@ where { if let Some(by) = fwd_by(forwarded) { tracing::info!(%forwarded); - if by == local_id.as_ref() { + if by == local_id.as_str() { return Box::pin(future::err(GatewayLoop.into())); } } diff --git a/linkerd/app/inbound/src/http/set_identity_header.rs b/linkerd/app/inbound/src/http/set_identity_header.rs index fcd5ca50ff..9db9f3014b 100644 --- a/linkerd/app/inbound/src/http/set_identity_header.rs +++ b/linkerd/app/inbound/src/http/set_identity_header.rs @@ -43,7 +43,7 @@ where .and_then(|tls| match tls { tls::ServerTls::Established { client_id, .. } => { client_id.as_ref().and_then(|id| { - match http::HeaderValue::from_str(id.as_ref().as_ref()) { + match http::HeaderValue::from_str(id.as_str()) { Ok(v) => Some(v), Err(error) => { tracing::warn!(%error, "identity not a valid header value"); diff --git a/linkerd/app/inbound/src/policy/mod.rs b/linkerd/app/inbound/src/policy/mod.rs index c88bfe9b26..315020a359 100644 --- a/linkerd/app/inbound/src/policy/mod.rs +++ b/linkerd/app/inbound/src/policy/mod.rs @@ -144,8 +144,8 @@ impl AllowPolicy { .. }) = tls { - if identities.contains(id.as_ref()) - || suffixes.iter().any(|s| s.contains(id.as_ref())) + if identities.contains(id.as_str()) + || suffixes.iter().any(|s| s.contains(id.as_str())) { return Ok(Permit::new(self.dst, &*server, authz)); } diff --git a/linkerd/app/outbound/src/http/require_id_header.rs b/linkerd/app/outbound/src/http/require_id_header.rs index a25c2f9609..bc9c431039 100644 --- a/linkerd/app/outbound/src/http/require_id_header.rs +++ b/linkerd/app/outbound/src/http/require_id_header.rs @@ -86,7 +86,7 @@ where if let Some(require_id) = Self::extract_id(&mut request) { match self.tls.as_ref() { Conditional::Some(tls::ClientTls { server_id, .. }) => { - if require_id != *server_id.as_ref() { + if require_id != **server_id { debug!( required = %require_id, found = %server_id, diff --git a/linkerd/dns/name/Cargo.toml b/linkerd/dns/name/Cargo.toml index d3215b1739..0a9bbb3c43 100644 --- a/linkerd/dns/name/Cargo.toml +++ b/linkerd/dns/name/Cargo.toml @@ -9,4 +9,3 @@ publish = false [dependencies] thiserror = "1.0" untrusted = "0.7" -webpki = "0.21" diff --git a/linkerd/dns/name/src/lib.rs b/linkerd/dns/name/src/lib.rs index 9aff437180..f0222e25cb 100644 --- a/linkerd/dns/name/src/lib.rs +++ b/linkerd/dns/name/src/lib.rs @@ -4,5 +4,5 @@ mod name; mod suffix; -pub use self::name::{InvalidName, Name}; +pub use self::name::{InvalidName, Name, NameRef}; pub use self::suffix::Suffix; diff --git a/linkerd/dns/name/src/name.rs b/linkerd/dns/name/src/name.rs index add757da31..7cb1ecd748 100644 --- a/linkerd/dns/name/src/name.rs +++ b/linkerd/dns/name/src/name.rs @@ -1,88 +1,319 @@ -use std::convert::TryFrom; -use std::fmt; +// Based on https://github.com/briansmith/webpki/blob/18cda8a5e32dfc2723930018853a984bd634e667/src/name/dns_name.rs +// +// Copyright 2015-2020 Brian Smith. +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR +// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +use std::{fmt, ops::Deref}; use thiserror::Error; -/// A `Name` is guaranteed to be syntactically valid. The validity rules +/// A DNS Name suitable for use in the TLS Server Name Indication (SNI) +/// extension and/or for use as the reference hostname for which to verify a +/// certificate. +/// +/// A `Name` is guaranteed to be syntactically valid. The validity rules are +/// specified in [RFC 5280 Section 7.2], except that underscores are also +/// allowed. +/// +/// `Name` stores a copy of the input it was constructed from in a `String` +/// and so it is only available when the `std` default feature is enabled. +/// +/// [RFC 5280 Section 7.2]: https://tools.ietf.org/html/rfc5280#section-7. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct Name(String); + +/// A reference to a DNS Name suitable for use in the TLS Server Name Indication +/// (SNI) extension and/or for use as the reference hostname for which to verify +/// a certificate. +/// +/// A `NameRef` is guaranteed to be syntactically valid. The validity rules /// are specified in [RFC 5280 Section 7.2], except that underscores are also /// allowed. -#[derive(Clone, Eq, PartialEq, Hash)] -pub struct Name(webpki::DNSName); +/// +/// [RFC 5280 Section 7.2]: https://tools.ietf.org/html/rfc5280#section-7.2 +#[derive(Clone, Copy, Debug, Eq, Hash)] +pub struct NameRef<'a>(&'a str); #[derive(Copy, Clone, Debug, Eq, PartialEq, Error)] #[error("invalid DNS name")] pub struct InvalidName; +// === impl Name === + impl Name { + /// Constructs a `Name` if the input is a syntactically-valid DNS name. + #[inline] + pub fn try_from_ascii(n: &[u8]) -> Result { + let n = NameRef::try_from_ascii(n)?; + Ok(n.to_owned()) + } + + #[inline] + pub fn as_ref(&self) -> NameRef<'_> { + NameRef(self.0.as_str()) + } + + #[inline] + pub fn as_str(&self) -> &str { + self.0.as_str() + } + + #[inline] + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() + } + #[inline] pub fn is_localhost(&self) -> bool { - self.as_ref().eq_ignore_ascii_case("localhost.") + self.as_str().eq_ignore_ascii_case("localhost.") } #[inline] pub fn without_trailing_dot(&self) -> &str { - self.as_ref().trim_end_matches('.') + self.as_str().trim_end_matches('.') } } -impl fmt::Debug for Name { +impl fmt::Display for Name { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - let s: &str = AsRef::::as_ref(&self.0); - s.fmt(f) + self.0.fmt(f) } } -impl fmt::Display for Name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - let s: &str = AsRef::::as_ref(&self.0); - s.fmt(f) +impl std::str::FromStr for Name { + type Err = InvalidName; + + #[inline] + fn from_str(n: &str) -> Result { + Self::try_from_ascii(n.as_bytes()) } } -impl From for Name { - fn from(n: webpki::DNSName) -> Self { - Name(n) +impl Deref for Name { + type Target = str; + + #[inline] + fn deref(&self) -> &str { + self.0.as_str() } } -impl<'a> TryFrom<&'a [u8]> for Name { - type Error = InvalidName; - fn try_from(s: &[u8]) -> Result { - webpki::DNSNameRef::try_from_ascii(s) - .map_err(|_invalid| InvalidName) - .map(|r| r.to_owned().into()) +// === impl NameRef === + +impl<'a> NameRef<'a> { + /// Constructs a `NameRef` from the given input if the input is a + /// syntactically-valid DNS name. + pub fn try_from_ascii(dns_name: &'a [u8]) -> Result { + if !is_valid_reference_dns_id(untrusted::Input::from(dns_name)) { + return Err(InvalidName); + } + + let s = std::str::from_utf8(dns_name).map_err(|_| InvalidName)?; + Ok(Self(s)) } -} -impl std::str::FromStr for Name { - type Err = InvalidName; - fn from_str(s: &str) -> Result { - Self::try_from(s.as_bytes()) + pub fn try_from_ascii_str(n: &'a str) -> Result { + Self::try_from_ascii(n.as_bytes()) + } + + /// Constructs a `Name` from this `NameRef` + pub fn to_owned(self) -> Name { + // NameRef is already guaranteed to be valid ASCII, which is a + // subset of UTF-8. + Name(self.as_str().to_ascii_lowercase()) + } + + #[inline] + pub fn as_str(&self) -> &str { + self.0 + } + + #[inline] + pub fn as_bytes(&self) -> &[u8] { + self.0.as_bytes() } } -impl From for webpki::DNSName { - fn from(Name(name): Name) -> webpki::DNSName { - name +impl<'a> PartialEq> for NameRef<'_> { + fn eq(&self, other: &NameRef<'a>) -> bool { + self.0.eq_ignore_ascii_case(other.0) } } -impl<'t> From<&'t Name> for webpki::DNSNameRef<'t> { - fn from(Name(ref name): &'t Name) -> webpki::DNSNameRef<'t> { - name.as_ref() +impl fmt::Display for NameRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + self.as_str().fmt(f) } } -impl AsRef for Name { - #[inline] - fn as_ref(&self) -> &str { - >::as_ref(&self.0) +// === Helpers === + +fn is_valid_reference_dns_id(hostname: untrusted::Input<'_>) -> bool { + is_valid_dns_id(hostname) +} + +// https://tools.ietf.org/html/rfc5280#section-4.2.1.6: +// +// When the subjectAltName extension contains a domain name system +// label, the domain name MUST be stored in the dNSName (an IA5String). +// The name MUST be in the "preferred name syntax", as specified by +// Section 3.5 of [RFC1034] and as modified by Section 2.1 of +// [RFC1123]. +// +// https://bugzilla.mozilla.org/show_bug.cgi?id=1136616: As an exception to the +// requirement above, underscores are also allowed in names for compatibility. +fn is_valid_dns_id(hostname: untrusted::Input<'_>) -> bool { + // https://blogs.msdn.microsoft.com/oldnewthing/20120412-00/?p=7873/ + if hostname.len() > 253 { + return false; } + + let mut input = untrusted::Reader::new(hostname); + + let mut label_length = 0; + let mut label_is_all_numeric = false; + let mut label_ends_with_hyphen = false; + + loop { + const MAX_LABEL_LENGTH: usize = 63; + + match input.read_byte() { + Ok(b'-') => { + if label_length == 0 { + return false; // Labels must not start with a hyphen. + } + label_is_all_numeric = false; + label_ends_with_hyphen = true; + label_length += 1; + if label_length > MAX_LABEL_LENGTH { + return false; + } + } + + Ok(b'0'..=b'9') => { + if label_length == 0 { + label_is_all_numeric = true; + } + label_ends_with_hyphen = false; + label_length += 1; + if label_length > MAX_LABEL_LENGTH { + return false; + } + } + + Ok(b'a'..=b'z') | Ok(b'A'..=b'Z') | Ok(b'_') => { + label_is_all_numeric = false; + label_ends_with_hyphen = false; + label_length += 1; + if label_length > MAX_LABEL_LENGTH { + return false; + } + } + + Ok(b'.') => { + if label_ends_with_hyphen { + return false; // Labels must not end with a hyphen. + } + if label_length == 0 { + return false; + } + label_length = 0; + } + + _ => { + return false; + } + } + + if input.at_end() { + break; + } + } + + if label_ends_with_hyphen { + return false; // Labels must not end with a hyphen. + } + + if label_is_all_numeric { + return false; // Last label must not be all numeric. + } + + true } #[cfg(test)] mod tests { use super::*; - use std::str::FromStr; + + #[test] + fn test_parse() { + const CASES: &[(&str, bool)] = &[ + ("", false), + (".", false), + ("..", false), + ("...", false), + ("*", false), + ("a", true), + ("a.", true), + ("d.c.b.a", true), + ("d.c.b.a.", true), + (" d.c.b.a.", false), + ("d.c.b.a-", false), + ("*.a.", false), + (".a.", false), + ("a1", true), + ("_m.foo.bar", true), + ("m.foo.bar_", true), + ("example.com:80", false), + ("1", false), + ("1.a", true), + ("a.1", false), + ("1.2.3.4", false), + ("::1", false), + ("xn--poema-9qae5a.com.br", true), // IDN + ]; + for &(n, expected_result) in CASES { + assert!(n.parse::().is_ok() == expected_result, "{}", n); + } + } + + #[test] + fn test_eq() { + const CASES: &[(&str, &str, bool)] = &[ + ("a", "a", true), + ("a", "b", false), + ("d.c.b.a", "d.c.b.a", true), + // case sensitivity + ( + "abcdefghijklmnopqrstuvwxyz", + "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + true, + ), + ("aBc", "Abc", true), + ("a1", "A1", true), + ("example", "example", true), + ("example.", "example.", true), + ("example", "example.", false), + ("example.com", "example.com", true), + ("example.com.", "example.com.", true), + ("example.com", "example.com.", false), + ]; + for &(left, right, expected_result) in CASES { + let l = left.parse::().unwrap(); + let r = right.parse::().unwrap(); + assert_eq!(l == r, expected_result, "{:?} vs {:?}", l, r); + } + } #[test] fn test_is_localhost() { @@ -94,7 +325,7 @@ mod tests { ("localhost1.", false), // suffixed ]; for (host, expected_result) in cases { - let dns_name = Name::try_from(host.as_bytes()).unwrap(); + let dns_name = host.parse::().unwrap(); assert_eq!(dns_name.is_localhost(), *expected_result, "{:?}", dns_name) } } @@ -109,7 +340,8 @@ mod tests { ("web.svc.local.", "web.svc.local"), ]; for (host, expected_result) in cases { - let dns_name = Name::try_from(host.as_bytes()) + let dns_name = host + .parse::() .unwrap_or_else(|_| panic!("'{}' was invalid", host)); assert_eq!( dns_name.without_trailing_dot(), @@ -118,7 +350,8 @@ mod tests { dns_name ) } - assert!(Name::from_str(".").is_err()); - assert!(Name::from_str("").is_err()); + assert!(".".parse::().is_err()); + assert!("..".parse::().is_err()); + assert!("".parse::().is_err()); } } diff --git a/linkerd/dns/src/lib.rs b/linkerd/dns/src/lib.rs index b476e8c3d6..16dea3e39f 100644 --- a/linkerd/dns/src/lib.rs +++ b/linkerd/dns/src/lib.rs @@ -1,6 +1,7 @@ #![deny(warnings, rust_2018_idioms)] #![forbid(unsafe_code)] +use linkerd_dns_name::NameRef; pub use linkerd_dns_name::{InvalidName, Name, Suffix}; use linkerd_error::Error; use std::{fmt, net}; @@ -59,7 +60,7 @@ impl Resolver { /// record lookups. pub async fn resolve_addrs( &self, - name: &Name, + name: NameRef<'_>, default_port: u16, ) -> Result<(Vec, time::Sleep), Error> { match self.resolve_srv(name).await { @@ -78,24 +79,29 @@ impl Resolver { async fn resolve_a( &self, - name: &Name, + name: NameRef<'_>, ) -> Result<(Vec, time::Sleep), ResolveError> { debug!(%name, "resolve_a"); - let lookup = self.dns.lookup_ip(name.as_ref()).await?; + let lookup = self.dns.lookup_ip(name.as_str()).await?; let valid_until = Instant::from_std(lookup.valid_until()); let ips = lookup.iter().collect::>(); Ok((ips, time::sleep_until(valid_until))) } - async fn resolve_srv(&self, name: &Name) -> Result<(Vec, time::Sleep), Error> { + async fn resolve_srv( + &self, + name: NameRef<'_>, + ) -> Result<(Vec, time::Sleep), Error> { debug!(%name, "resolve_srv"); - let srv = self.dns.srv_lookup(name.as_ref()).await?; + let srv = self.dns.srv_lookup(name.as_str()).await?; + let valid_until = Instant::from_std(srv.as_lookup().valid_until()); let addrs = srv .into_iter() .map(Self::srv_to_socket_addr) .collect::>()?; debug!(ttl = ?valid_until - time::Instant::now(), ?addrs); + Ok((addrs, time::sleep_until(valid_until))) } @@ -184,7 +190,7 @@ mod tests { for case in VALID { let name = Name::from_str(case.input); - assert_eq!(name.as_ref().map(|x| x.as_ref()), Ok(case.output)); + assert_eq!(name.as_deref(), Ok(case.output)); } static INVALID: &[&str] = &[ @@ -239,7 +245,7 @@ mod tests { #[cfg(fuzzing)] pub mod fuzz_logic { use super::*; - use std::str::FromStr; + pub struct FuzzConfig {} // Empty config resolver that we can use. @@ -249,11 +255,11 @@ pub mod fuzz_logic { // Test the resolvers do not panic unexpectedly. pub async fn fuzz_entry(fuzz_data: &str) { - if let Ok(name) = Name::from_str(fuzz_data) { + if let Ok(name) = fuzz_data.parse::() { let fcon = FuzzConfig {}; let resolver = Resolver::from_system_config_with(&fcon).unwrap(); - let _w = resolver.resolve_a(&name).await; - let _w2 = resolver.resolve_srv(&name).await; + let _w = resolver.resolve_a(name.as_ref()).await; + let _w2 = resolver.resolve_srv(name.as_ref()).await; } } } diff --git a/linkerd/identity/src/lib.rs b/linkerd/identity/src/lib.rs index 04104d3974..2365f60438 100644 --- a/linkerd/identity/src/lib.rs +++ b/linkerd/identity/src/lib.rs @@ -4,7 +4,7 @@ pub use ring::error::KeyRejected; use ring::rand; use ring::signature::EcdsaKeyPair; -use std::{convert::TryFrom, fmt, fs, io, str::FromStr, sync::Arc, time::SystemTime}; +use std::{fmt, fs, io, ops::Deref, str::FromStr, sync::Arc, time::SystemTime}; use thiserror::Error; use tokio_rustls::rustls; use tracing::{debug, warn}; @@ -137,17 +137,11 @@ impl From for Name { } } -impl<'t> From<&'t LocalId> for webpki::DNSNameRef<'t> { - fn from(LocalId(ref name): &'t LocalId) -> webpki::DNSNameRef<'t> { - name.into() - } -} - impl FromStr for Name { type Err = InvalidName; fn from_str(s: &str) -> Result { - if s.as_bytes().last() == Some(&b'.') { + if s.ends_with('.') { return Err(InvalidName); // SNI hostnames are implicitly absolute. } @@ -155,27 +149,11 @@ impl FromStr for Name { } } -impl TryFrom<&[u8]> for Name { - type Error = InvalidName; - - fn try_from(s: &[u8]) -> Result { - if s.last() == Some(&b'.') { - return Err(InvalidName); // SNI hostnames are implicitly absolute. - } +impl Deref for Name { + type Target = linkerd_dns_name::Name; - linkerd_dns_name::Name::try_from(s).map(|n| Name(Arc::new(n))) - } -} - -impl<'t> From<&'t Name> for webpki::DNSNameRef<'t> { - fn from(Name(ref name): &'t Name) -> webpki::DNSNameRef<'t> { - name.as_ref().into() - } -} - -impl AsRef for Name { - fn as_ref(&self) -> &str { - (*self.0).as_ref() + fn deref(&self) -> &Self::Target { + &self.0 } } @@ -249,6 +227,9 @@ impl TrustAnchors { pub fn certify(&self, key: Key, crt: Crt) -> Result { let mut client = self.0.as_ref().clone(); + let crt_id = webpki::DNSNameRef::try_from_ascii(crt.id.as_bytes()) + .expect("certificate ID must be a valid DNS name"); + // Ensure the certificate is valid for the services we terminate for // TLS. This assumes that server cert validation does the same or // more validation than client cert validation. @@ -260,11 +241,11 @@ impl TrustAnchors { // XXX: Once `rustls::ServerCertVerified` is exposed in Rustls's // safe API, use it to pass proof to CertCertResolver::new.... // - // TODO: Restrict accepted signatutre algorithms. + // TODO: Restrict accepted signature algorithms. static NO_OCSP: &[u8] = &[]; client .get_verifier() - .verify_server_cert(&client.root_store, &crt.chain, (&crt.id).into(), NO_OCSP) + .verify_server_cert(&client.root_store, &crt.chain, crt_id, NO_OCSP) .map_err(InvalidCrt)?; debug!("certified {}", crt.id); @@ -326,7 +307,7 @@ impl Crt { } pub fn name(&self) -> &Name { - self.id.as_ref() + &*self.id } } @@ -340,7 +321,7 @@ impl From<&'_ Crt> for LocalId { impl CrtKey { pub fn name(&self) -> &Name { - self.id.as_ref() + &*self.id } pub fn expiry(&self) -> SystemTime { @@ -442,8 +423,10 @@ impl From for Name { } } -impl AsRef for LocalId { - fn as_ref(&self) -> &Name { +impl Deref for LocalId { + type Target = Name; + + fn deref(&self) -> &Name { &self.0 } } diff --git a/linkerd/proxy/dns-resolve/src/lib.rs b/linkerd/proxy/dns-resolve/src/lib.rs index ed85287e86..88008cc983 100644 --- a/linkerd/proxy/dns-resolve/src/lib.rs +++ b/linkerd/proxy/dns-resolve/src/lib.rs @@ -70,7 +70,7 @@ async fn resolution(dns: dns::Resolver, na: NameAddr) -> Result Result { debug!(?addrs, name = %na); let eps = addrs.into_iter().map(|a| (a, ())).collect(); diff --git a/linkerd/proxy/identity/src/certify.rs b/linkerd/proxy/identity/src/certify.rs index ae814b6cee..c765db53aa 100644 --- a/linkerd/proxy/identity/src/certify.rs +++ b/linkerd/proxy/identity/src/certify.rs @@ -204,7 +204,7 @@ impl LocalCrtKey { } pub fn name(&self) -> &id::Name { - self.id.as_ref() + &*self.id } pub fn client_config(&self) -> tls::client::Config { diff --git a/linkerd/tls/src/client.rs b/linkerd/tls/src/client.rs index ba051a4cb0..5bb99d3101 100644 --- a/linkerd/tls/src/client.rs +++ b/linkerd/tls/src/client.rs @@ -9,6 +9,7 @@ use linkerd_stack::{layer, Param}; use std::{ fmt, future::Future, + ops::Deref, pin::Pin, str::FromStr, sync::Arc, @@ -146,7 +147,9 @@ where let connect = self.inner.call(target); Either::Right(Box::pin(async move { let io = connect.await?; - let io = handshake.connect((&server_id.0).into(), io).await?; + let sni = webpki::DNSNameRef::try_from_ascii(server_id.as_bytes()) + .expect("identity must be a valid DNS-like name"); + let io = handshake.connect(sni, io).await?; if let Some(alpn) = io.get_ref().1.get_alpn_protocol() { debug!(alpn = ?std::str::from_utf8(alpn)); } @@ -169,8 +172,10 @@ impl From for id::Name { } } -impl AsRef for ServerId { - fn as_ref(&self) -> &id::Name { +impl Deref for ServerId { + type Target = id::Name; + + fn deref(&self) -> &id::Name { &self.0 } } diff --git a/linkerd/tls/src/server/client_hello.rs b/linkerd/tls/src/server/client_hello.rs index 393e58466b..bc77aa2988 100644 --- a/linkerd/tls/src/server/client_hello.rs +++ b/linkerd/tls/src/server/client_hello.rs @@ -1,27 +1,23 @@ use crate::ServerId; use linkerd_identity as id; -use std::convert::TryFrom; use tracing::trace; #[derive(Debug, Eq, PartialEq)] pub struct Incomplete; -/// Determintes whether the given `input` looks like the start of a TLS -/// connection. +/// Determines whether the given `input` looks like the start of a TLS connection. /// -/// The determination is made based on whether the input looks like (the start -/// of) a valid ClientHello that a reasonable TLS client might send, and the -/// SNI matches the given identity. +/// The determination is made based on whether the input looks like (the start of) a valid +/// ClientHello that a reasonable TLS client might send, and the SNI matches the given identity. /// -/// XXX: Once the TLS record header is matched, the determination won't be -/// made until the entire TLS record including the entire ClientHello handshake -/// message is available. +/// XXX: Once the TLS record header is matched, the determination won't be made until the entire TLS +/// record including the entire ClientHello handshake message is available. /// /// TODO: Reject non-matching inputs earlier. /// -/// This assumes that the ClientHello is small and is sent in a single TLS -/// record, which is what all reasonable implementations do. (If they were not -/// to, they wouldn't interoperate with picky servers.) +/// This assumes that the ClientHello is small and is sent in a single TLS record, which is what all +/// reasonable implementations do. (If they were not to, they wouldn't interoperate with picky +/// servers.) pub fn parse_sni(input: &[u8]) -> Result, Incomplete> { let r = untrusted::Input::from(input).read_all(untrusted::EndOfInput, |input| { let r = extract_sni(input); @@ -30,11 +26,15 @@ pub fn parse_sni(input: &[u8]) -> Result, Incomplete> { }); match r { Ok(Some(sni)) => { - let sni = id::Name::try_from(sni.as_slice_less_safe()) + let sni = match std::str::from_utf8(sni.as_slice_less_safe()) .ok() - .map(ServerId); + .and_then(|n| n.parse::().ok()) + { + Some(sni) => sni, + None => return Ok(None), + }; trace!(?sni, "parse_sni: parsed correctly up to SNI"); - Ok(sni) + Ok(Some(ServerId(sni))) } Ok(None) => { trace!("parse_sni: failed to parse up to SNI"); diff --git a/linkerd/tls/src/server/mod.rs b/linkerd/tls/src/server/mod.rs index 437ca7d0cd..cae1f31ca1 100644 --- a/linkerd/tls/src/server/mod.rs +++ b/linkerd/tls/src/server/mod.rs @@ -11,6 +11,7 @@ use linkerd_io::{self as io, AsyncReadExt, EitherIo, PrefixedIo}; use linkerd_stack::{layer, ExtractParam, InsertParam, NewService, Param}; use std::{ fmt, + ops::Deref, pin::Pin, str::FromStr, sync::Arc, @@ -304,7 +305,12 @@ fn client_identity(tls: &TlsStream) -> Option { match dns_names.first()? { GeneralDNSNameRef::DNSName(n) => { - Some(ClientId(id::Name::from(dns::Name::from(n.to_owned())))) + // Unfortunately we have to allocate a new string here, since there's no way to get the + // underlying bytes from a `DNSNameRef`. + let name = AsRef::::as_ref(&n.to_owned()) + .parse::() + .ok()?; + Some(ClientId(name.into())) } GeneralDNSNameRef::Wildcard(_) => { // Wildcards can perhaps be handled in a future path... @@ -327,8 +333,10 @@ impl From for id::Name { } } -impl AsRef for ClientId { - fn as_ref(&self) -> &id::Name { +impl Deref for ClientId { + type Target = id::Name; + + fn deref(&self) -> &id::Name { &self.0 } }