diff --git a/Cargo.toml b/Cargo.toml index 57840e7e..54f1b925 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,15 +26,20 @@ tempfile = "3.1.0" [target.'cfg(target_os = "windows")'.dependencies] schannel = "0.1.17" -[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios", target_os = "espidf")))'.dependencies] +[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios", target_env = "sgx")))'.dependencies] log = "0.4.5" openssl = "0.10.29" openssl-sys = "0.9.55" openssl-probe = "0.1" -[target.'cfg(target_os = "espidf")'.dependencies] -mbedtls = { version = "0.8.1", features = ["pkcs12", "std"], path = "/home/mabez/development/rust/embedded/util/rust-mbedtls/mbedtls" } +[target.'cfg(target_env = "sgx")'.dependencies] +mbedtls = { version = "0.9", features = ["std", "rdrand", "mpi_force_c_code" ], default-features = false } +pkcs5 = { version = "0.7.1", features = ["alloc", "pbes2"] } +p12 = "0.6.3" +yasna = "0.5" [dev-dependencies] +lazy_static = "1.4.0" tempfile = "3.0" test-cert-gen = "0.9" +ureq = "2.6" diff --git a/src/imp/mbedtls.rs b/src/imp/mbedtls.rs index 26f045dc..1dcf87bb 100644 --- a/src/imp/mbedtls.rs +++ b/src/imp/mbedtls.rs @@ -1,88 +1,56 @@ extern crate mbedtls; -use self::mbedtls::ssl::context::IoCallback; - use self::mbedtls::alloc::{Box as MbedtlsBox, List as MbedtlsList}; use self::mbedtls::hash::{Md, Type as MdType}; use self::mbedtls::pk::Pk; -use self::mbedtls::pkcs12::{Pfx, Pkcs12Error}; -use self::mbedtls::rng::{CtrDrbg, OsEntropy}; +use self::mbedtls::rng::{CtrDrbg, Rdseed}; +use self::mbedtls::ssl::config::NullTerminatedStrList; use self::mbedtls::ssl::config::{Endpoint, Preset, Transport}; use self::mbedtls::ssl::{Config, Context, Version}; use self::mbedtls::x509::certificate::Certificate as MbedtlsCert; use self::mbedtls::Error as TlsError; -use self::mbedtls::Result as TlsResult; +use std::convert::TryFrom; use std::error; use std::fmt::{self, Debug}; -use std::fs; -use std::io::{self, Read}; +use std::io; use std::sync::Arc; use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; -fn load_ca_certs(dir: &str) -> TlsResult> { - let paths = fs::read_dir(dir).map_err(|_| TlsError::X509FileIoError)?; - - let mut certs = Vec::new(); - - for path in paths { - if let Ok(mut file) = fs::File::open(path.unwrap().path()) { - let mut contents = Vec::new(); - if let Ok(_) = file.read_to_end(&mut contents) { - contents.push(0); // needs NULL terminator - if let Ok(cert) = ::Certificate::from_pem(&contents) { - certs.push(cert); - } - } - } - } - - Ok(certs) -} - -fn load_system_trust_roots() -> Result, Error> { - let paths = [ - "/etc/pki/CA/certs", // Fedora, RHEL - "/usr/share/ca-certificates/mozilla", // Ubuntu, Debian, Arch, Gentoo - ]; - - for path in paths.iter() { - if let Ok(certs) = load_ca_certs(path) { - return Ok(certs); - } - } - - Err(Error::Custom( - "Could not load system default trust roots".to_owned(), - )) -} - #[derive(Debug)] pub enum Error { - Normal(TlsError), - Pkcs12(Pkcs12Error), + Tls(TlsError), + Pkcs12(yasna::ASN1Error), + Pkcs5(pkcs5::Error), + Der(pkcs5::der::Error), Custom(String), } -#[derive(Debug, Copy, Clone)] -enum ProtocolRole { - Client, - Server, -} - impl From for Error { fn from(err: TlsError) -> Error { - Error::Normal(err) + Error::Tls(err) } } -impl From for Error { - fn from(err: Pkcs12Error) -> Error { +impl From for Error { + fn from(err: yasna::ASN1Error) -> Error { Error::Pkcs12(err) } } +impl From for Error { + fn from(err: pkcs5::Error) -> Error { + Error::Pkcs5(err) + } +} + +impl From for Error { + fn from(err: pkcs5::der::Error) -> Error { + Error::Der(err) + } +} + impl From for HandshakeError { fn from(e: TlsError) -> HandshakeError { HandshakeError::Failure(e.into()) @@ -91,66 +59,166 @@ impl From for HandshakeError { impl error::Error for Error { fn source(&self) -> Option<&(dyn error::Error + 'static)> { - // error::Error::source(&self) - todo!() + match *self { + Error::Tls(ref e) => e.source(), + Error::Pkcs12(ref e) => e.source(), + Error::Pkcs5(_) => None, + Error::Der(_) => None, + Error::Custom(_) => None, + } } } impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match *self { - Error::Normal(ref e) => fmt::Display::fmt(e, fmt), + Error::Tls(ref e) => fmt::Display::fmt(e, fmt), Error::Pkcs12(ref e) => fmt::Display::fmt(e, fmt), + Error::Pkcs5(ref e) => fmt::Display::fmt(e, fmt), + Error::Der(ref e) => fmt::Display::fmt(e, fmt), Error::Custom(ref e) => fmt::Display::fmt(e, fmt), } } } -fn map_version(protocol: Option) -> Option { - if let Some(protocol) = protocol { - match protocol { - Protocol::Sslv3 => Some(Version::Ssl3), - Protocol::Tlsv10 => Some(Version::Tls1_0), - Protocol::Tlsv11 => Some(Version::Tls1_1), - Protocol::Tlsv12 => Some(Version::Tls1_2), - _ => None, - } +fn to_mbedtls_version(protocol: Protocol) -> Version { + match protocol { + Protocol::Sslv3 => Version::Ssl3, + Protocol::Tlsv10 => Version::Tls1_0, + Protocol::Tlsv11 => Version::Tls1_1, + Protocol::Tlsv12 => Version::Tls1_2, + } +} + +trait NullTerminated { + fn null_terminated(&self) -> Vec; +} + +impl> NullTerminated for T { + fn null_terminated(&self) -> Vec { + let mut buf = self.as_ref().to_vec(); + buf.push(0); + buf + } +} + +fn pkcs12_decode_key_bag>( + key_bag: &p12::EncryptedPrivateKeyInfo, + pass: B, +) -> Result, Error> { + // try to decrypt the key with algorithms supported by p12 crate + if let Some(decrypted) = key_bag.decrypt(pass.as_ref()) { + Ok(decrypted) + // try to decrypt the key with algorithms supported by pkcs5 standard + } else if let p12::AlgorithmIdentifier::OtherAlg(_) = key_bag.encryption_algorithm { + // write the algorithm identifier back to DER format + let algorithm_der = + yasna::construct_der(|writer| key_bag.encryption_algorithm.write(writer)); + // and construct pkcs5 decoder from it + let scheme = pkcs5::EncryptionScheme::try_from(&algorithm_der[..])?; + + Ok(scheme.decrypt(pass.as_ref(), &key_bag.encrypted_data)?) } else { - None + Err(Error::Custom( + "Unsupported key encryption algorithm".to_owned(), + )) } } -pub struct Identity(Pfx); +#[derive(Clone)] +pub struct Identity { + key: Arc, + certificates: Arc>, +} impl Identity { pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result { - let pkcs12 = Pfx::parse(buf).map_err(Error::Pkcs12)?; - let decrypted = pkcs12.decrypt(&pass, None).map_err(Error::Pkcs12)?; - Ok(Identity(decrypted)) + let pfx = p12::PFX::parse(buf)?; + let key = pfx + .bags(pass)? + .iter() + .find_map(|safe_bag| { + if let p12::SafeBagKind::Pkcs8ShroudedKeyBag(ref key_bag) = safe_bag.bag { + Some(pkcs12_decode_key_bag(key_bag, pass)) + } else { + None + } + }) + .ok_or(Error::Custom("No private key in pkcs12 DER".to_owned()))? + .map(|key| Pk::from_private_key(&key, Some(pass.as_bytes())))??; + let certificates: MbedtlsList<_> = pfx + .cert_bags(pass)? + .iter() + .map(|cert| MbedtlsCert::from_der(cert)) + .collect::>()?; + + if !certificates.is_empty() { + Ok(Identity { + key: Arc::new(key), + certificates: Arc::new(certificates), + }) + } else { + Err(Error::Custom( + "PKCS12 file is missing certificate chain".to_owned(), + )) + } + } + + pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result { + let key = Pk::from_private_key(&key.null_terminated(), None)?; + let certificates = MbedtlsCert::from_pem_multiple(&buf.null_terminated())?; + + if !certificates.is_empty() { + Ok(Identity { + key: Arc::new(key), + certificates: Arc::new(certificates), + }) + } else { + Err(Error::Custom( + "X509 chain file is missing certificate chain".to_owned(), + )) + } + } + + fn certificates(&self) -> Arc> { + self.certificates.clone() + } + + fn private_key(&self) -> Arc { + self.key.clone() } } -impl Clone for Identity { - fn clone(&self) -> Self { - Identity(self.0.clone()) +impl Debug for Identity { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Identity") + .field( + "certificates", + &self + .certificates + .iter() + .map(|cert| cert.as_der().to_vec()) + .collect::>(), + ) + .field( + "key_name", + &self.key.name().map(String::from).map_err(Error::Tls), + ) + .finish() } } #[derive(Clone)] pub struct Certificate(MbedtlsBox); -unsafe impl Sync for Certificate {} impl Certificate { pub fn from_der(buf: &[u8]) -> Result { - let cert = MbedtlsCert::from_der(buf).map_err(Error::Normal)?; + let cert = MbedtlsCert::from_der(buf).map_err(Error::Tls)?; Ok(Certificate(cert)) } pub fn from_pem(buf: &[u8]) -> Result { - // Mbedtls needs there to be a trailing NULL byte ... - let mut pem = buf.to_vec(); - pem.push(0); - let cert = MbedtlsCert::from_pem(&pem).map_err(Error::Normal)?; + let cert = MbedtlsCert::from_pem(&buf.null_terminated()).map_err(Error::Tls)?; Ok(Certificate(cert)) } @@ -160,63 +228,136 @@ impl Certificate { } } -fn cert_to_vec(certs_in: &[::Certificate]) -> Vec> { - certs_in.iter().map(|cert| (cert.0).0.clone()).collect() +impl Debug for Certificate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Certificate") + .field(&self.0.as_der()) + .finish() + } } -#[allow(unused)] pub struct TlsStream { - role: ProtocolRole, - ca_certs: Vec>, - ca_cert_list: Arc>, - cred_pk: Option>, - cred_certs: Vec>, - cred_cert_list: Arc>, - entropy: Arc, - rng: Arc, - config: Arc, ctx: Context, + role: Endpoint, + identity: Option, +} + +impl TlsStream { + pub fn get_ref(&self) -> &S { + self.ctx.io().expect("Not connected") + } + + pub fn get_mut(&mut self) -> &mut S { + self.ctx.io_mut().expect("Not connected") + } + + pub fn buffered_read_size(&self) -> Result { + Ok(self.ctx.bytes_available()) + } + + #[cfg(feature = "alpn")] + pub fn negotiated_alpn(&self) -> Result>, Error> { + Ok(self.ctx.get_alpn_protocol()?.map(|s| s.as_bytes().to_vec())) + } + + pub fn peer_certificate(&self) -> Result, Error> { + let cert = match self.ctx.peer_cert() { + Ok(Some(certs)) => certs.iter().next().map(|cert| Certificate(cert.clone())), + Ok(_) => None, + Err(e) => match e { + TlsError::SslBadInputData => None, + _ => return Err(Error::Tls(e)), + }, + }; + Ok(cert) + } + + fn server_certificate(&self) -> Result, Error> { + match self.role { + Endpoint::Client => self.peer_certificate(), + Endpoint::Server => match self.identity { + Some(ref idt) => Ok(idt + .certificates() + .iter() + .map(|cert| Certificate(cert.clone())) + .next()), + None => Ok(None), + }, + } + } + + pub fn tls_server_end_point(&self) -> Result>, Error> { + let cert = match self.server_certificate()? { + Some(cert) => cert, + None => return Ok(None), + }; + + let md = match cert.0.digest_type() { + MdType::Md5 | MdType::Sha1 => MdType::Sha256, + md => md, + }; + + let der = cert.to_der()?; + let mut digest = vec![0; 64]; + let len = Md::hash(md, &der, &mut digest).map_err(Error::Tls)?; + digest.truncate(len); + + Ok(Some(digest)) + } + + pub fn shutdown(&mut self) -> io::Result<()> { + self.ctx.close(); + Ok(()) + } } -unsafe impl Sync for TlsStream {} -unsafe impl Send for TlsStream {} +impl io::Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.ctx.read(buf) + } +} + +impl io::Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.ctx.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.ctx.flush() + } +} impl Debug for TlsStream { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TlsStream") - .field("role", &self.role) - .field("ca_certs", &self.ca_certs) - // .field("ca_cert_list", &self.ca_cert_list) - // .field("cred_pk", &self.cred_pk) - // .field("cred_certs", &self.cred_certs) - // .field("cred_cert_list", &self.cred_cert_list) - // .field("entropy", &self.entropy) - // .field("rng", &self.rng) - // .field("config", &self.config) - // .field("ctx", &self.ctx) - // .field("stream", &self.stream) + .field( + "role", + &match self.role { + Endpoint::Client => "client", + Endpoint::Server => "server", + }, + ) + .field("identity", &self.identity) .finish() } } #[derive(Debug)] -pub struct MidHandshakeTlsStream { - stream: TlsStream, - error: Error, -} +pub struct MidHandshakeTlsStream(TlsStream); pub enum HandshakeError { Failure(Error), + // this is actually unused WouldBlock(MidHandshakeTlsStream), } impl MidHandshakeTlsStream { pub fn get_ref(&self) -> &S { - self.stream.get_ref() + self.0.get_ref() } pub fn get_mut(&mut self) -> &mut S { - self.stream.get_mut() + self.0.get_mut() } } @@ -225,131 +366,90 @@ where S: io::Read + io::Write, { pub fn handshake(self) -> Result, HandshakeError> { - Ok(self.stream) + Ok(self.0) } } #[derive(Clone)] pub struct TlsConnector { - min_protocol: Option, - max_protocol: Option, - root_certificates: Vec<::Certificate>, + config: Arc, identity: Option<::Identity>, - accept_invalid_certs: bool, accept_invalid_hostnames: bool, - use_sni: bool, } impl Debug for TlsConnector { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TlsConnector") - .field("min_protocol", &self.min_protocol) - .field("max_protocol", &self.max_protocol) - // .field("root_certificates", &self.root_certificates) - // .field("identity", &self.identity) - .field("accept_invalid_certs", &self.accept_invalid_certs) + .field("identity", &self.identity.as_ref().map(|idt| &idt.0)) .field("accept_invalid_hostnames", &self.accept_invalid_hostnames) - .field("use_sni", &self.use_sni) .finish() } } impl TlsConnector { pub fn new(builder: &TlsConnectorBuilder) -> Result { - let trust_roots = if builder.root_certificates.len() > 0 { - builder.root_certificates.clone() - } else { - load_system_trust_roots()? - }; - - Ok(TlsConnector { - min_protocol: builder.min_protocol, - max_protocol: builder.max_protocol, - root_certificates: trust_roots, - identity: builder.identity.clone(), - accept_invalid_certs: builder.accept_invalid_certs, - accept_invalid_hostnames: builder.accept_invalid_hostnames, - use_sni: builder.use_sni, - }) - } - - pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> - where - S: IoCallback, - { - println!("CONNECTING IN MBETLS"); - let identity = if let Some(identity) = &self.identity { - let mut keys = (identity.0).0.private_keys().collect::>(); - let certificates = (identity.0).0.certificates().collect::>(); - - if keys.len() != 1 { - return Err(HandshakeError::Failure(Error::Custom( - "Unexpected number of keys in PKCS12 file".to_owned(), - ))); - } - if certificates.len() == 0 { - return Err(HandshakeError::Failure(Error::Custom( - "PKCS12 file is missing certificate chain".to_owned(), - ))); - } - - let mut cert_chain = vec![]; - for cert in certificates { - cert_chain.push(cert.0?); - } - - fn pk_clone(pk: &mut Pk) -> TlsResult { - let der = pk.write_private_der_vec()?; - Pk::from_private_key(&der, None) - } - let key = Box::new(keys.pop().unwrap().0.map_err(|_| TlsError::PkInvalidAlg)?); - - Some((cert_chain, key)) - } else { - None - }; - - let ca_vec = cert_to_vec(&self.root_certificates); - let mut ca_list = MbedtlsList::new(); - ca_vec.clone().into_iter().for_each(|c| ca_list.push(c)); - let ca_list = Arc::new(ca_list); - - let entropy = Arc::new(OsEntropy::new()); - let rng = Arc::new(CtrDrbg::new(entropy.clone(), None)?); let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); - config.set_rng(rng.clone()); - config.set_ca_list(ca_list.clone(), None); - - let mut cred_certs = Default::default(); - let mut cred_cert_list = Arc::new(MbedtlsList::new()); - let mut cred_pk = None; - - if let Some((certificates, mut pk)) = identity { - cred_certs = certificates.to_vec(); - let mut tmp = MbedtlsList::new(); - cred_certs.clone().into_iter().for_each(|c| tmp.push(c)); - cred_cert_list = Arc::new(tmp); - - let cpk = Arc::new(Pk::from_private_key(&pk.write_private_der_vec()?, None)?); - cred_pk = Some(cpk.clone()); - config.push_cert(cred_cert_list.clone(), cpk.clone())?; + // Set Rng + let entropy = Arc::new(Rdseed); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + config.set_rng(rng); + + // Set root certificates + let ca_list = builder + .root_certificates + .iter() + .map(|cert| (cert.0).0.clone()) + .collect(); + config.set_ca_list(Arc::new(ca_list), None); + + // Add identity certificates and key + if let Some(identity) = &builder.identity { + config.push_cert(identity.0.certificates(), identity.0.private_key())?; } - if self.accept_invalid_certs { + // Set authmode + if builder.accept_invalid_certs { config.set_authmode(mbedtls::ssl::config::AuthMode::None); } - if let Some(min_version) = map_version(self.min_protocol) { + // Set minimum protocol version + if let Some(min_version) = builder.min_protocol.map(to_mbedtls_version) { config.set_min_version(min_version)?; } - if let Some(max_version) = map_version(self.max_protocol) { + + // Set maximum protocol version + if let Some(max_version) = builder.max_protocol.map(to_mbedtls_version) { config.set_max_version(max_version)?; } - let config = Arc::new(config); - let mut ctx = Context::new(config.clone()); + #[cfg(feature = "alpn")] + { + if !builder.alpn.is_empty() { + let alpns: Vec<_> = builder + .alpn + .iter() + .map(|protocol| protocol.as_str()) + .collect(); + config.set_alpn_protocols(Arc::new(NullTerminatedStrList::new(&alpns)?))?; + } + } + + Ok(TlsConnector { + config: Arc::new(config), + identity: builder.identity.clone(), + accept_invalid_hostnames: builder.accept_invalid_hostnames, + }) + } + + pub fn connect(&self, domain: &str, stream: S) -> Result, HandshakeError> + where + S: io::Read + io::Write, + { + // Create mbedtls context + let mut ctx = Context::new(self.config.clone()); + // Establish connection let hostname = if self.accept_invalid_hostnames { None } else { @@ -359,174 +459,64 @@ impl TlsConnector { ctx.establish(stream, hostname)?; Ok(TlsStream { - role: ProtocolRole::Client, - ca_certs: ca_vec, - ca_cert_list: ca_list, - cred_pk: cred_pk, - cred_certs: cred_certs, - cred_cert_list: cred_cert_list, - entropy, - rng, - config, ctx, + role: Endpoint::Client, + identity: self.identity.clone().map(|idt| idt.0), }) } } #[derive(Clone)] pub struct TlsAcceptor { - identity: Pfx, - min_protocol: Option, - max_protocol: Option, + config: Arc, + identity: Identity, } impl TlsAcceptor { pub fn new(builder: &TlsAcceptorBuilder) -> Result { - Ok(TlsAcceptor { - identity: (builder.identity.0).0.clone(), - min_protocol: builder.min_protocol, - max_protocol: builder.max_protocol, - }) - } - - pub fn accept(&self, stream: S) -> Result, HandshakeError> - where - S: IoCallback, - { - println!("ACCEPTING IN MBETLS"); - let mut keys = self.identity.private_keys().collect::>(); - let certificates = self.identity.certificates().collect::>(); - - if keys.len() != 1 { - return Err(HandshakeError::Failure(Error::Custom( - "Unexpected number of keys in PKCS12 file".to_owned(), - ))); - } - if certificates.len() == 0 { - return Err(HandshakeError::Failure(Error::Custom( - "PKCS12 file is missing certificate chain".to_owned(), - ))); - } - - let mut cert_chain = vec![]; - for cert in certificates { - cert_chain.push(cert.0?); - } - - let key: &mut Pk = &mut keys.pop().unwrap().0.map_err(|_| TlsError::PkInvalidAlg)?; + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); - let pk = Arc::new(Pk::from_private_key(&key.write_private_der_vec()?, None)?); - let mut cert_list = MbedtlsList::new(); - cert_chain - .to_vec() - .into_iter() - .for_each(|c| cert_list.push(c)); - let cert_list = Arc::new(cert_list); + // Set Rng + let entropy = Arc::new(Rdseed); + let rng = Arc::new(CtrDrbg::new(entropy, None)?); + config.set_rng(rng); - let entropy = Arc::new(OsEntropy::new()); - let rng = Arc::new(CtrDrbg::new(entropy.clone(), None)?); - let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); - config.set_rng(rng.clone()); - config.push_cert(cert_list.clone(), pk.clone())?; + // Add identity certificates and key + config.push_cert( + builder.identity.0.certificates(), + builder.identity.0.private_key(), + )?; - if let Some(min_version) = map_version(self.min_protocol) { + // Set minimum protocol version + if let Some(min_version) = builder.min_protocol.map(to_mbedtls_version) { config.set_min_version(min_version)?; } - if let Some(max_version) = map_version(self.max_protocol) { + + // Set maximum protocol version + if let Some(max_version) = builder.max_protocol.map(to_mbedtls_version) { config.set_max_version(max_version)?; } - let config = Arc::new(config); + Ok(TlsAcceptor { + config: Arc::new(config), + identity: (builder.identity.0).clone(), + }) + } - let mut ctx = Context::new(config.clone()); + pub fn accept(&self, stream: S) -> Result, HandshakeError> + where + S: io::Read + io::Write, + { + // Create mbedtls context + let mut ctx = Context::new(self.config.clone()); + // Establish connection ctx.establish(stream, None)?; Ok(TlsStream { - role: ProtocolRole::Server, - ca_certs: Vec::new(), - ca_cert_list: Arc::new(MbedtlsList::new()), - cred_pk: Some(pk), - cred_certs: cert_chain, - cred_cert_list: cert_list, - entropy, - rng, - config, ctx, + role: Endpoint::Server, + identity: Some(self.identity.clone()), }) } } - -impl TlsStream { - pub fn get_ref(&self) -> &S { - self.ctx.io().unwrap() - } - - pub fn get_mut(&mut self) -> &mut S { - self.ctx.io_mut().unwrap() - } - - pub fn buffered_read_size(&self) -> Result { - Ok(self.ctx.bytes_available()) - } - - pub fn peer_certificate(&self) -> Result, Error> { - match self.ctx.peer_cert()? { - None => Ok(None), - Some(certs) => match certs.iter().next() { - None => Ok(None), - Some(c) => Ok(Some(Certificate::from_der(c.as_der())?)), - }, - } - } - - fn server_certificate(&self) -> Result, Error> { - match self.role { - ProtocolRole::Client => self.peer_certificate(), - ProtocolRole::Server => match self.cred_certs.first() { - None => Ok(None), - Some(c) => Ok(Some(Certificate::from_der(c.as_der())?)), - }, - } - } - - pub fn tls_server_end_point(&self) -> Result>, Error> { - let cert = match self.server_certificate()? { - Some(cert) => cert, - None => return Ok(None), - }; - - let md = match cert.0.digest_type() { - MdType::Md5 | MdType::Sha1 => MdType::Sha256, - md => md, - }; - - let der = cert.to_der()?; - let mut digest = vec![0; 64]; - let len = Md::hash(md, &der, &mut digest).map_err(Error::Normal)?; - digest.truncate(len); - - Ok(Some(digest)) - } - - pub fn shutdown(&mut self) -> io::Result<()> { - // Shutdown happens as a result of drop ... - Ok(()) - } -} - -impl io::Read for TlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.ctx.read(buf) - } -} - -impl io::Write for TlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.ctx.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.ctx.flush() - } -} diff --git a/src/lib.rs b/src/lib.rs index bc907130..80916211 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,18 +108,32 @@ use std::fmt; use std::io; use std::result; -#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))] +#[cfg(not(any( + target_os = "macos", + target_os = "windows", + target_os = "ios", + target_env = "sgx" +)))] #[macro_use] extern crate log; #[cfg(any(target_os = "macos", target_os = "ios"))] #[path = "imp/security_framework.rs"] mod imp; -#[cfg(target_os = "windows")] +#[cfg(all(target_os = "windows", not(target_env = "sgx")))] #[path = "imp/schannel.rs"] mod imp; -#[cfg(not(any(target_os = "macos", target_os = "windows", target_os = "ios")))] +#[cfg(not(any( + target_os = "macos", + target_os = "windows", + target_os = "ios", + target_os = "espidf", + target_env = "sgx" +)))] #[path = "imp/openssl.rs"] mod imp; +#[cfg(target_env = "sgx")] +#[path = "imp/mbedtls.rs"] +mod imp; #[cfg(test)] mod test; @@ -713,8 +727,10 @@ fn _check_kinds() { is_send::(); is_sync::(); is_send::(); + #[cfg(not(target_env = "sgx"))] is_sync::>(); is_send::>(); + #[cfg(not(target_env = "sgx"))] is_sync::>(); is_send::>(); } diff --git a/src/test.rs b/src/test.rs index c51b0bc4..1b68bd38 100644 --- a/src/test.rs +++ b/src/test.rs @@ -7,6 +7,38 @@ use std::thread; use super::*; +#[cfg(target_env = "sgx")] +lazy_static::lazy_static! { + static ref ROOT_CERTIFICATES: Vec = { + // except digicert just because we have to provide any exclusion to get the rest + let mut root_certs = ureq::get("https://mkcert.org/generate/all/except/digicert") + .call() + .unwrap() + .into_string() + .unwrap(); + root_certs.push('\0'); + let root_certs = mbedtls::x509::certificate::Certificate::from_pem_multiple(root_certs.as_bytes()).unwrap(); + root_certs.iter().map(|cert| Certificate::from_der(cert.as_der()).unwrap()).collect() + }; +} + +// for mbedtls there is no 'standard' way to get default ca root chain +// so for tests where some default is needed we manually add mozilla trust chain. +macro_rules! connector { + () => {{ + #[cfg(target_env = "sgx")] + { + let mut builder = TlsConnector::builder(); + ROOT_CERTIFICATES.iter().for_each(|cert| { + builder.add_root_certificate(cert.clone()); + }); + builder + } + #[cfg(not(target_env = "sgx"))] + TlsConnector::builder() + }}; +} + macro_rules! p { ($e:expr) => { match $e { @@ -18,7 +50,7 @@ macro_rules! p { #[test] fn connect_google() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); let mut socket = p!(builder.connect("google.com", s)); @@ -26,23 +58,20 @@ fn connect_google() { let mut result = vec![]; p!(socket.read_to_end(&mut result)); - println!("{}", String::from_utf8_lossy(&result)); assert!(result.starts_with(b"HTTP/1.0")); assert!(result.ends_with(b"\r\n") || result.ends_with(b"")); } #[test] fn connect_bad_hostname() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); builder.connect("goggle.com", s).unwrap_err(); } #[test] fn connect_bad_hostname_ignored() { - let builder = p!(TlsConnector::builder() - .danger_accept_invalid_hostnames(true) - .build()); + let builder = p!(connector!().danger_accept_invalid_hostnames(true).build()); let s = p!(TcpStream::connect("google.com:443")); builder.connect("goggle.com", s).unwrap(); } @@ -408,7 +437,7 @@ fn shutdown() { #[test] #[cfg(feature = "alpn")] fn alpn_google_h2() { - let builder = p!(TlsConnector::builder().request_alpns(&["h2"]).build()); + let builder = p!(connector!().request_alpns(&["h2"]).build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn()); @@ -418,7 +447,7 @@ fn alpn_google_h2() { #[test] #[cfg(feature = "alpn")] fn alpn_google_invalid() { - let builder = p!(TlsConnector::builder().request_alpns(&["h2c"]).build()); + let builder = p!(connector!().request_alpns(&["h2c"]).build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn()); @@ -428,7 +457,7 @@ fn alpn_google_invalid() { #[test] #[cfg(feature = "alpn")] fn alpn_google_none() { - let builder = p!(TlsConnector::new()); + let builder = p!(connector!().build()); let s = p!(TcpStream::connect("google.com:443")); let socket = p!(builder.connect("google.com", s)); let alpn = p!(socket.negotiated_alpn());