Skip to content

Commit 693ccda

Browse files
elemountShuode Li
authored and
Shuode Li
committed
Support SSL in windows
1 parent e88ee0a commit 693ccda

File tree

6 files changed

+139
-23
lines changed

6 files changed

+139
-23
lines changed

Cargo.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ lto = true
4343
default = ['mysql_common']
4444
nightly = ['mysql_common']
4545
rustc_serialize = ['mysql_common/rustc_serialize', 'rustc-serialize']
46-
ssl = ['mysql_common', "openssl", "security-framework"]
46+
ssl = ['mysql_common', "openssl", "security-framework", "schannel"]
4747

4848
[dev-dependencies]
4949
serde_derive = "1"
@@ -79,6 +79,10 @@ version = "~0.2"
7979
optional = true
8080
features = ["OSX_10_9"]
8181

82+
[target.'cfg(target_os = "windows")'.dependencies.schannel]
83+
version = "~0.1"
84+
optional = true
85+
8286
[target.'cfg(target_os = "windows")'.dependencies]
8387
named_pipe = "~0.3"
8488
winapi = "~0.3"

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@ features = ["rustc-serialize"]
4141
```
4242

4343
### Windows support (since 0.18.0)
44-
Windows is supported but currently rust-mysql-simple has no support for SSL on Windows.
44+
Windows is supported.

appveyor.yml

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ before_test:
4646
4747
$newText = ([System.IO.File]::ReadAllText($iniPath)).Replace("# enable-named-pipe", "enable-named-pipe")
4848
49+
$newText = $newText + "`nssl-ca=c:/clone/tests/ca-cert.pem`nssl-cert=c:/clone/tests/server-cert.pem`nssl-key=c:/clone/tests/server-key.pem"
50+
4951
[System.IO.File]::WriteAllText($iniPath, $newText)
5052
5153
Restart-Service MySQL57

src/conn/mod.rs

+13-9
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,12 @@ impl Conn {
779779
}
780780
}
781781

782-
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
782+
#[cfg(not(feature = "ssl"))]
783+
fn switch_to_ssl(&mut self) -> MyResult<()> {
784+
unimplemented!();
785+
}
786+
787+
#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
783788
fn switch_to_ssl(&mut self) -> MyResult<()> {
784789
match self.stream.take() {
785790
Some(ConnStream::Plain(stream)) => {
@@ -798,11 +803,6 @@ impl Conn {
798803
Ok(())
799804
}
800805

801-
#[cfg(any(not(feature = "ssl"), target_os = "windows"))]
802-
fn switch_to_ssl(&mut self) -> MyResult<()> {
803-
unimplemented!();
804-
}
805-
806806
fn connect_stream(&mut self) -> MyResult<()> {
807807
let read_timeout = self.opts.get_read_timeout().cloned();
808808
let write_timeout = self.opts.get_write_timeout().cloned();
@@ -2075,7 +2075,11 @@ mod test {
20752075
builder.into()
20762076
}
20772077

2078-
#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
2078+
#[cfg(all(
2079+
feature = "ssl",
2080+
not(target_os = "macos"),
2081+
any(unix, target_os = "windows")
2082+
))]
20792083
pub fn get_opts() -> Opts {
20802084
let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string());
20812085
let port: u16 = env::var("MYSQL_SERVER_PORT")
@@ -2099,7 +2103,7 @@ mod test {
20992103
builder.into()
21002104
}
21012105

2102-
#[cfg(any(not(feature = "ssl"), target_os = "windows"))]
2106+
#[cfg(not(feature = "ssl"))]
21032107
pub fn get_opts() -> Opts {
21042108
let pwd: String = env::var("MYSQL_SERVER_PASS").unwrap_or(PASS.to_string());
21052109
let port: u16 = env::var("MYSQL_SERVER_PORT")
@@ -2556,7 +2560,7 @@ mod test {
25562560
}
25572561

25582562
#[test]
2559-
#[cfg(all(feature = "ssl", any(target_os = "macos", unix)))]
2563+
#[cfg(all(feature = "ssl", any(target_os = "macos", target_os = "windows", unix)))]
25602564
fn should_connect_via_ssl() {
25612565
let mut opts = OptsBuilder::from_opts(get_opts());
25622566
opts.prefer_socket(false);

src/conn/opts.rs

+8-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::consts::CapabilityFlags;
22

33
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
4-
#[cfg(all(feature = "ssl", not(target_os = "windows")))]
4+
#[cfg(all(feature = "ssl"))]
55
use std::path;
66
use std::str::FromStr;
77

@@ -29,8 +29,8 @@ pub type SslOpts = Option<Option<(path::PathBuf, String, Vec<path::PathBuf>)>>;
2929
pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>;
3030

3131
#[cfg(all(feature = "ssl", target_os = "windows"))]
32-
/// Not implemented on Windows
33-
pub type SslOpts = Option<()>;
32+
/// Ssl options: Option<(pem_ca_cert, Option<(pem_client_cert, pem_client_key)>)>.`
33+
pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>;
3434

3535
#[cfg(not(feature = "ssl"))]
3636
/// Requires `ssl` feature
@@ -445,7 +445,11 @@ impl OptsBuilder {
445445
self
446446
}
447447

448-
#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
448+
#[cfg(all(
449+
feature = "ssl",
450+
not(target_os = "macos"),
451+
any(unix, target_os = "windows")
452+
))]
449453
/// SSL certificates and keys in pem format.
450454
///
451455
/// If not None, then ssl connection implied.
@@ -487,11 +491,6 @@ impl OptsBuilder {
487491
self
488492
}
489493

490-
/// Not implemented on windows
491-
#[cfg(all(feature = "ssl", target_os = "windows"))]
492-
pub fn ssl_opts<A, B, C>(&mut self, _: Option<SslOpts>) -> &mut Self {
493-
panic!("OptsBuilder::ssl_opts is not implemented on Windows");
494-
}
495494

496495
/// Requires `ssl` feature
497496
#[cfg(not(feature = "ssl"))]

src/io/mod.rs

+110-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::net::SocketAddr;
99
use std::slice::Chunks;
1010
use std::time::Duration;
1111

12-
#[cfg(all(feature = "ssl", not(target_os = "windows")))]
12+
#[cfg(all(feature = "ssl"))]
1313
use crate::conn::SslOpts;
1414

1515
use super::consts;
@@ -32,6 +32,14 @@ use flate2::{read::ZlibDecoder, write::ZlibEncoder, Compression};
3232
use named_pipe as np;
3333
#[cfg(all(feature = "ssl", all(unix, not(target_os = "macos"))))]
3434
use openssl::ssl::{self, SslContext, SslStream};
35+
#[cfg(all(feature = "ssl", target_os = "windows"))]
36+
use schannel::cert_context::CertContext;
37+
#[cfg(all(feature = "ssl", target_os = "windows"))]
38+
use schannel::cert_store;
39+
#[cfg(all(feature = "ssl", target_os = "windows"))]
40+
use schannel::schannel_cred;
41+
#[cfg(all(feature = "ssl", target_os = "windows"))]
42+
use schannel::tls_stream;
3543
#[cfg(all(feature = "ssl", target_os = "macos"))]
3644
use security_framework::certificate::SecCertificate;
3745
#[cfg(all(feature = "ssl", target_os = "macos"))]
@@ -763,6 +771,103 @@ impl Stream {
763771
}
764772
}
765773

774+
#[cfg(all(feature = "ssl", target_os = "windows"))]
775+
impl Stream {
776+
pub fn make_secure(
777+
mut self,
778+
verify_peer: bool,
779+
ip_or_hostname: Option<&str>,
780+
ssl_opts: &SslOpts,
781+
) -> MyResult<Stream> {
782+
use std::path::Path;
783+
784+
fn load_cert_data(path: &Path) -> MyResult<String> {
785+
let mut client_file = ::std::fs::File::open(path)?;
786+
let mut client_data = String::new();
787+
client_file.read_to_string(&mut client_data)?;
788+
Ok(client_data)
789+
}
790+
791+
fn load_client_cert(path: &Path) -> MyResult<CertContext> {
792+
let cert_data = load_cert_data(path)?;
793+
let cert = CertContext::from_pem(&cert_data)?;
794+
Ok(cert)
795+
}
796+
797+
fn load_client_cert_with_key(cert_path: &Path, key_path: &Path) -> MyResult<CertContext> {
798+
let mut cert_data = load_cert_data(cert_path)?;
799+
let cert = CertContext::from_pem(&cert_data)?;
800+
let key_data = load_cert_data(key_path)?;
801+
cert_data.push_str(&key_data);
802+
Ok(cert)
803+
}
804+
805+
fn load_ca_store(path: &Path) -> MyResult<cert_store::CertStore> {
806+
let ca_cert = load_client_cert(path)?;
807+
let mut cert_store = cert_store::Memory::new().unwrap().into_store();
808+
cert_store.add_cert(&ca_cert, cert_store::CertAdd::Always)?;
809+
Ok(cert_store)
810+
}
811+
812+
if self.is_insecure() {
813+
let mut stream_builder = tls_stream::Builder::new();
814+
let mut cred_builder = schannel_cred::Builder::default();
815+
cred_builder.enabled_protocols(&[
816+
schannel_cred::Protocol::Tls10,
817+
schannel_cred::Protocol::Tls11,
818+
]);
819+
cred_builder.supported_algorithms(&[
820+
schannel_cred::Algorithm::DhEphem,
821+
schannel_cred::Algorithm::RsaSign,
822+
schannel_cred::Algorithm::Aes256,
823+
schannel_cred::Algorithm::Sha1,
824+
]);
825+
if verify_peer {
826+
stream_builder.domain(ip_or_hostname.as_ref().unwrap_or(&("localhost".into())));
827+
}
828+
829+
match *ssl_opts {
830+
Some((ref ca_cert, None)) => {
831+
stream_builder.cert_store(load_ca_store(&ca_cert)?);
832+
}
833+
Some((ref ca_cert, Some((ref client_cert, ref client_key)))) => {
834+
cred_builder.cert(load_client_cert_with_key(&client_cert, &client_key)?);
835+
stream_builder.cert_store(load_ca_store(&ca_cert)?);
836+
}
837+
_ => unreachable!(),
838+
}
839+
840+
let cred = cred_builder.acquire(schannel_cred::Direction::Outbound)?;
841+
match self {
842+
Stream::TcpStream(ref mut opt_stream) if opt_stream.is_some() => {
843+
let stream = opt_stream.take().unwrap();
844+
match stream {
845+
TcpStream::Insecure(mut stream) => {
846+
stream.flush()?;
847+
let s_stream = match stream_builder
848+
.connect(cred, stream.into_inner().unwrap())
849+
{
850+
Ok(s_stream) => s_stream,
851+
Err(tls_stream::HandshakeError::Failure(err)) => {
852+
return Err(err.into());
853+
}
854+
Err(tls_stream::HandshakeError::Interrupted(_)) => unreachable!(),
855+
};
856+
Ok(Stream::TcpStream(Some(TcpStream::Secure(BufStream::new(
857+
s_stream,
858+
)))))
859+
}
860+
_ => unreachable!(),
861+
}
862+
}
863+
_ => unreachable!(),
864+
}
865+
} else {
866+
Ok(self)
867+
}
868+
}
869+
}
870+
766871
#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
767872
impl Stream {
768873
pub fn make_secure(
@@ -838,13 +943,15 @@ impl Drop for Stream {
838943
pub enum TcpStream {
839944
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
840945
Secure(BufStream<SslStream<net::TcpStream>>),
946+
#[cfg(all(feature = "ssl", target_os = "windows"))]
947+
Secure(BufStream<tls_stream::TlsStream<net::TcpStream>>),
841948
Insecure(BufStream<net::TcpStream>),
842949
}
843950

844951
impl AsMut<dyn IoPack> for TcpStream {
845952
fn as_mut(&mut self) -> &mut dyn IoPack {
846953
match *self {
847-
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
954+
#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
848955
TcpStream::Secure(ref mut stream) => stream,
849956
TcpStream::Insecure(ref mut stream) => stream,
850957
}
@@ -854,7 +961,7 @@ impl AsMut<dyn IoPack> for TcpStream {
854961
impl fmt::Debug for TcpStream {
855962
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
856963
match *self {
857-
#[cfg(all(feature = "ssl", any(unix, target_os = "macos")))]
964+
#[cfg(all(feature = "ssl", any(unix, target_os = "macos", target_os = "windows")))]
858965
TcpStream::Secure(_) => write!(f, "Secure stream"),
859966
TcpStream::Insecure(ref s) => write!(f, "Insecure stream {:?}", s),
860967
}

0 commit comments

Comments
 (0)