Skip to content

Add PSK support #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,8 @@ required-features = ["std"]
name = "ssl_conf_verify"
path = "tests/ssl_conf_verify.rs"
required-features = ["std"]

[[test]]
name = "ssl_conf_psk"
path = "tests/ssl_conf_psk.rs"
required-features = ["std"]
45 changes: 43 additions & 2 deletions mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,49 @@ impl<'c> Config<'c> {
)
}
}

/// psk and psk_identity cannot be empty
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
assert!(psk_identity.len()>0);
assert!(psk.len()>0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mbedtls_ssl_conf_psk function already checks these arguments. I don't see a reason to check them here again

unsafe { ssl_conf_psk(&mut self.inner,
psk.as_ptr(), psk.len(),
psk_identity.as_ptr(), psk_identity.len())
.into_result().map(|_| ())
}
}

pub fn set_psk_callback<F>(&mut self, cb: &'c mut F)
where
F: FnMut(&mut HandshakeContext, &str) -> Result<()>,
{
unsafe extern "C" fn psk_callback<F>(
closure: *mut c_void,
ctx: *mut ssl_context,
psk_identity: *const c_uchar,
identity_len: size_t) -> c_int
where
F: FnMut(&mut HandshakeContext, &str) -> Result<()>,
{
assert!(identity_len>0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, I don't see a reason for the assert.

let cb = &mut *(closure as *mut F);
let mut ctx = UnsafeFrom::from(ctx).expect("valid context");
let psk_identity = &*(from_raw_parts(psk_identity, identity_len)
as *const [u8] as *const str);
match cb(&mut ctx, psk_identity) {
Ok(()) => 0,
Err(e) => e.to_int(),
}
}

unsafe {
ssl_conf_psk_cb(
&mut self.inner,
Some(psk_callback::<F>),
cb as *mut F as _
)
}
}
}

/// Builds a linked list of x509_crt instances, all of which are owned by mbedtls. That is, the
Expand Down Expand Up @@ -417,8 +460,6 @@ impl<'a> Iterator for KeyCertIter<'a> {
// ssl_conf_dtls_badmac_limit
// ssl_conf_handshake_timeout
// ssl_conf_session_cache
// ssl_conf_psk
// ssl_conf_psk_cb
// ssl_conf_sig_hashes
// ssl_conf_alpn_protocols
// ssl_conf_fallback
Expand Down
11 changes: 10 additions & 1 deletion mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ impl<'ctx> HandshakeContext<'ctx> {
.map(|_| ())
}
}

/// psk cannot be empty
pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> {
assert!(psk.len()>0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, I don't see a reason for the assert.

unsafe {
ssl_set_hs_psk(self.inner, psk.as_ptr(), psk.len())
.into_result()
.map(|_| ())
}
}
}

impl<'ctx> ::core::ops::Deref for HandshakeContext<'ctx> {
Expand Down Expand Up @@ -315,7 +325,6 @@ impl<'a> Drop for Session<'a> {
// ssl_renegotiate
// ssl_send_alert_message
// ssl_set_client_transport_id
// ssl_set_hs_psk
// ssl_set_timer_cb
//
// ssl_handshake_step
Expand Down
58 changes: 58 additions & 0 deletions mbedtls/tests/ssl_conf_psk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#![allow(dead_code)]
extern crate mbedtls;

use std::net::TcpStream;

mod support;
use support::entropy::entropy_new;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context, HandshakeContext};
use mbedtls::Result as TlsResult;


fn client(mut conn: TcpStream, psk: &[u8]) -> TlsResult<()> {
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: this doesn't need to be in it's own block

let mut entropy = entropy_new();
let mut rng = CtrDrbg::new(&mut entropy, None)?;
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(Some(&mut rng));
config.set_psk(psk, "Client_identity")?;
let mut ctx = Context::new(&config)?;
ctx.establish(&mut conn, None).map(|_| ())?;
Ok(())
}
}

fn server<F>(mut conn: TcpStream, mut psk_callback: F) -> TlsResult<()>
where
F: FnMut(&mut HandshakeContext, &str) -> TlsResult<()> {
let mut entropy = entropy_new();
let mut rng = CtrDrbg::new(&mut entropy, None)?;
let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default);
config.set_rng(Some(&mut rng));
config.set_psk_callback(&mut psk_callback);
let mut ctx = Context::new(&config)?;
let _ = ctx.establish(&mut conn, None)?;
Ok(())
}

#[cfg(unix)]
mod test {
use super::*;
use std::thread;
use crate::support::net::create_tcp_pair;
use crate::support::keys;

#[test]
fn callback_standard_psk() {
let (c, s) = create_tcp_pair().unwrap();
let psk_callback =
|ctx: &mut HandshakeContext, _: &str| { ctx.set_psk(keys::PRESHARED_KEY) };
let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap());
let s = thread::spawn(move || super::server(s, psk_callback).unwrap());
c.join().unwrap();
s.join().unwrap();
}
}
4 changes: 4 additions & 0 deletions mbedtls/tests/support/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,7 @@ C4j3yqL0Gbs+moaswS1UR8XSnKt8TBcXVozCAy12A4qKSjkP7VKPTLeTOZxw0UBe
8CzQYNKoIGy4ayFVi+VKaNCHKvJm0diQkKw5Tz7L5quBBjt8JpmRtNbPsjXiq4Is
y14Xc4kb05mM5M9u685eWefa
-----END PRIVATE KEY-----\0";

pub const PRESHARED_KEY: &'static [u8] = &[
234, 206, 151, 23, 219, 21, 71, 144,
107, 42, 23, 67, 249, 173, 182, 224 ];