Skip to content

Commit

Permalink
Tracing crashes (#52)
Browse files Browse the repository at this point in the history
* Key can be converted from base16 string; other debugging improvements

* Format and update Cargo.lock

* Remove unused code

* Rename KernelNotSupported to PlatformNotSupported

* Revert KernelNotSupported

* Bump version
  • Loading branch information
moubctez authored Feb 5, 2024
1 parent a118bb7 commit d253efc
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 68 deletions.
30 changes: 15 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "defguard_wireguard_rs"
version = "0.4.0"
version = "0.4.1"
edition = "2021"
description = "A unified multi-platform high-level API for managing WireGuard interfaces"
license = "Apache-2.0"
Expand Down
10 changes: 5 additions & 5 deletions examples/userspace.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(target_os = "mac_os")]
use defguard_wireguard_rs::WireguardApiUserspace;
use defguard_wireguard_rs::{host::Peer, key::Key, net::IpAddrMask, InterfaceConfiguration};
#[cfg(target_os = "macos")]
use defguard_wireguard_rs::{WireguardApiUserspace, WireguardInterfaceApi};
use std::{
io::{stdin, stdout, Read, Write},
net::SocketAddr,
Expand All @@ -15,7 +15,7 @@ fn pause() {
stdin().read_exact(&mut [0]).unwrap();
}

#[cfg(target_os = "mac_os")]
#[cfg(target_os = "macos")]
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Setup API struct for interface management
let ifname: String = if cfg!(target_os = "linux") || cfg!(target_os = "freebsd") {
Expand All @@ -35,7 +35,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let peer_key: Key = key.as_ref().try_into().unwrap();
let mut peer = Peer::new(peer_key.clone());

log::info!("endpoint");
println!("endpoint");
// Your WireGuard server endpoint which peer connects too
let endpoint: SocketAddr = "10.20.30.40:55001".parse().unwrap();
// Peer endpoint and interval
Expand Down Expand Up @@ -73,5 +73,5 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[cfg(not(mac_os))]
#[cfg(not(target_os = "macos"))]
fn main() {}
6 changes: 5 additions & 1 deletion src/bsd/ifconfig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ impl IfReq {
.for_each(|(i, b)| ifr_name[i] = b);

// First, try to load a kernel module for this type of network interface.
let mod_name = format!("if_{if_name}");
// Omit digits at the end of interface name, e.g. "wg0" -> "if_wg".
let index = if_name
.find(|c: char| c.is_ascii_digit())
.unwrap_or(if_name.len());
let mod_name = format!("if_{}", &if_name[0..index]);
unsafe {
// Ignore the return value for the time being.
// Do the cast because `c_char` differs across platforms.
Expand Down
15 changes: 12 additions & 3 deletions src/bsd/wgio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ impl WgDataIo {
let socket = create_socket(AddressFamily::Unix).map_err(IoError::ReadIo)?;
unsafe {
// First do ioctl with empty `wg_data` to obtain buffer size.
read_wireguard_data(socket.as_raw_fd(), self).map_err(IoError::ReadIo)?;
if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo first read error {err}");
return Err(IoError::ReadIo(err));
}
// Allocate buffer.
self.alloc_data()?;
// Second call to ioctl with allocated buffer.
read_wireguard_data(socket.as_raw_fd(), self).map_err(IoError::ReadIo)?;
if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo second read error {err}");
return Err(IoError::ReadIo(err));
}
}

Ok(())
Expand All @@ -75,7 +81,10 @@ impl WgDataIo {
pub(super) fn write_data(&mut self) -> Result<(), IoError> {
let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?;
unsafe {
write_wireguard_data(socket.as_raw_fd(), self).map_err(IoError::WriteIo)?;
if let Err(err) = write_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo write error {err}");
return Err(IoError::WriteIo(err));
}
}

Ok(())
Expand Down
75 changes: 34 additions & 41 deletions src/key.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Public key utilities
use std::{
error, fmt,
fmt,
hash::{Hash, Hasher},
str::FromStr,
};
Expand All @@ -11,29 +11,21 @@ use serde::{Deserialize, Serialize};

const KEY_LENGTH: usize = 32;

/// Returns value of hex digit, if possible.
fn hex_value(char: u8) -> Option<u8> {
match char {
b'A'..=b'F' => Some(char - b'A' + 10),
b'a'..=b'f' => Some(char - b'a' + 10),
b'0'..=b'9' => Some(char - b'0'),
_ => None,
}
}

/// WireGuard key representation in binary form.
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(try_from = "&str")]
pub struct Key([u8; KEY_LENGTH]);

#[derive(Debug)]
pub enum KeyError {
InvalidCharacter(u8),
InvalidStringLength(usize),
}

impl error::Error for KeyError {}

impl fmt::Display for KeyError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::InvalidCharacter(char) => {
write!(f, "Invalid character {char}")
}
Self::InvalidStringLength(length) => write!(f, "Invalid string length {length}"),
}
}
}

impl Key {
/// Create a new key from buffer.
#[must_use]
Expand Down Expand Up @@ -71,28 +63,22 @@ impl Key {
/// Converts a text string of hexadecimal digits to `Key`.
///
/// # Errors
/// Will return `KeyError` if text string has wrong length,
/// Will return `DecodeError` if text string has wrong length,
/// or contains an invalid character.
pub fn decode<T: AsRef<[u8]>>(hex: T) -> Result<Self, KeyError> {
pub fn decode<T: AsRef<[u8]>>(hex: T) -> Result<Self, DecodeError> {
let hex = hex.as_ref();
let length = hex.len();
if length != 64 {
return Err(KeyError::InvalidStringLength(length));
if hex.len() != KEY_LENGTH * 2 {
return Err(DecodeError::InvalidLength);
}

let hex_value = |char: u8| -> Result<u8, KeyError> {
match char {
b'A'..=b'F' => Ok(char - b'A' + 10),
b'a'..=b'f' => Ok(char - b'a' + 10),
b'0'..=b'9' => Ok(char - b'0'),
_ => Err(KeyError::InvalidCharacter(char)),
}
};

let mut key = [0; KEY_LENGTH];
for (index, chunk) in hex.chunks(2).enumerate() {
let msd = hex_value(chunk[0])?;
let lsd = hex_value(chunk[1])?;
let Some(msd) = hex_value(chunk[0]) else {
return Err(DecodeError::InvalidByte(index, chunk[0]));
};
let Some(lsd) = hex_value(chunk[1]) else {
return Err(DecodeError::InvalidByte(index, chunk[1]));
};
key[index] = msd << 4 | lsd;
}
Ok(Self(key))
Expand All @@ -102,13 +88,20 @@ impl Key {
impl TryFrom<&str> for Key {
type Error = DecodeError;

/// Try to decode `Key` from base16 or base64 encoded string.
fn try_from(value: &str) -> Result<Self, Self::Error> {
let v = BASE64_STANDARD.decode(value)?;
if v.len() == KEY_LENGTH {
let buf = v.try_into().map_err(|_| Self::Error::InvalidLength)?;
Ok(Self::new(buf))
if value.len() == KEY_LENGTH * 2 {
// Try base16
Key::decode(value)
} else {
Err(Self::Error::InvalidLength)
// Try base64
let v = BASE64_STANDARD.decode(value)?;
if v.len() == KEY_LENGTH {
let buf = v.try_into().map_err(|_| Self::Error::InvalidLength)?;
Ok(Self::new(buf))
} else {
Err(Self::Error::InvalidLength)
}
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/wgapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ impl WGApi {
#[cfg(target_os = "freebsd")]
return Ok(Self(Box::new(WireguardApiFreebsd::new(ifname))));

#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
#[cfg(not(any(
target_os = "linux",
target_os = "freebsd",
target_os = "macos",
target_os = "windows"
)))]
Err(WireguardInterfaceError::KernelNotSupported)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/wgapi_freebsd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{

/// Manages interfaces created with FreeBSD kernel WireGuard module.
///
/// Requires FreeBSD version 14+.
/// Requires FreeBSD version 13+.
#[derive(Clone)]
pub struct WireguardApiFreebsd {
ifname: String,
Expand Down

0 comments on commit d253efc

Please sign in to comment.