Skip to content

Commit

Permalink
UAPI: update device listen port (zarvd#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
zarvd authored Apr 11, 2023
1 parent 93417d7 commit ac69e1e
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 151 deletions.
24 changes: 9 additions & 15 deletions examples/wg-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::error::Error;
use std::time::Duration;

use base64::engine::general_purpose::STANDARD as base64Encoding;
use base64::Engine;
use tokio::signal::unix::{signal, SignalKind};
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

Expand Down Expand Up @@ -52,21 +52,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
uapi::bind_and_handle(handle).await.unwrap();
});

shutdown().await;
let handle = device.handle();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(10)).await;
info!("Updating listen port");
let _ = handle.update_listen_port(9991).await;
});

tokio::signal::ctrl_c().await?;
device.terminate().await; // stop gracefully

Ok(())
}

pub async fn shutdown() {
tokio::select! {
() = recv_signal_and_shutdown(SignalKind::interrupt()) => {}
() = recv_signal_and_shutdown(SignalKind::terminate()) => {}
};

info!("recv signal and shutting down");
}

async fn recv_signal_and_shutdown(kind: SignalKind) {
signal(kind).expect("register signal handler").recv().await;
}
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ fmt:

# Lint code with clippy
lint:
cargo clippy
cargo clippy --all-targets --all-features
136 changes: 84 additions & 52 deletions src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ pub use peer::{Cidr, ParseCidrError, PeerMetrics};

use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use futures::future::join_all;
use futures::StreamExt;
Expand All @@ -40,7 +39,7 @@ where
cfg: Mutex<DeviceConfig>,
rate_limiter: RateLimiter,
cookie: Cookie,
inbound: Inbound,
inbound: Mutex<Inbound>,
}

impl<T> Inner<T>
Expand All @@ -60,7 +59,7 @@ where

#[inline]
pub fn endpoint_for(&self, dst: SocketAddr) -> Endpoint {
self.inbound.endpoint_for(dst)
self.inbound.lock().unwrap().endpoint_for(dst)
}
}

Expand All @@ -86,8 +85,9 @@ where
T: Tun + 'static,
{
inner: Arc<Inner<T>>,
handles: Vec<JoinHandle<()>>,
cancel_token: CancellationToken,
inbound_handles: Arc<Mutex<(CancellationToken, Vec<JoinHandle<()>>)>>,
outbound_handles: Arc<Mutex<(CancellationToken, Vec<JoinHandle<()>>)>>,
token: CancellationToken,
}

#[cfg(feature = "native")]
Expand All @@ -103,14 +103,14 @@ where
T: Tun + 'static,
{
pub async fn with_tun(tun: T, mut cfg: DeviceConfig) -> Result<Self, Error> {
let cancel_token = CancellationToken::new();
let token = CancellationToken::new();

let inbound = Inbound::bind(cfg.listen_port).await?;
cfg.listen_port = inbound.local_port();

let secret = LocalStaticSecret::new(cfg.private_key);

let peers = PeerIndex::new(cancel_token.child_token(), tun.clone(), secret.clone());
let peers = PeerIndex::new(token.child_token(), tun.clone(), secret.clone());
cfg.peers.iter().for_each(|p| {
peers.insert(
p.public_key,
Expand All @@ -133,41 +133,53 @@ where
cfg,
cookie,
rate_limiter,
inbound,
inbound: Mutex::new(inbound),
})
};
let handles = vec![
tokio::spawn(loop_tun_events(Arc::clone(&inner), cancel_token.clone())),
tokio::spawn(loop_outbound(Arc::clone(&inner), cancel_token.clone())),
tokio::spawn(loop_inbound(
Arc::clone(&inner),
listener_v4,
cancel_token.clone(),
)),
tokio::spawn(loop_inbound(
Arc::clone(&inner),
listener_v6,
cancel_token.clone(),
)),
];

let outbound_handles = {
let token = token.child_token();
Arc::new(Mutex::new((
token.clone(),
vec![tokio::spawn(loop_outbound(Arc::clone(&inner), token))],
)))
};
let inbound_handles = {
let token = token.child_token();
Arc::new(Mutex::new((
token.clone(),
vec![
tokio::spawn(loop_inbound(Arc::clone(&inner), listener_v4, token.clone())),
tokio::spawn(loop_inbound(Arc::clone(&inner), listener_v6, token)),
],
)))
};

Ok(Device {
inner,
handles,
cancel_token,
inbound_handles,
outbound_handles,
token,
})
}

#[inline]
pub fn handle(&self) -> DeviceHandle<T> {
DeviceHandle {
token: self.token.clone(),
inner: Arc::clone(&self.inner),
inbound_handles: Arc::clone(&self.inbound_handles),
}
}

pub async fn terminate(mut self) {
self.cancel_token.cancel();
join_all(self.handles.drain(..)).await;
pub async fn terminate(self) {
self.token.cancel();

let mut handles = vec![];
handles.extend(&mut self.inbound_handles.lock().unwrap().1.drain(..));
handles.extend(&mut self.outbound_handles.lock().unwrap().1.drain(..));

join_all(handles).await;
}
}

Expand All @@ -176,7 +188,7 @@ where
T: Tun,
{
fn drop(&mut self) {
self.cancel_token.cancel();
self.token.cancel();
}
}

Expand Down Expand Up @@ -209,7 +221,9 @@ pub struct DeviceHandle<T>
where
T: Tun + 'static,
{
token: CancellationToken,
inner: Arc<Inner<T>>,
inbound_handles: Arc<Mutex<(CancellationToken, Vec<JoinHandle<()>>)>>,
}

impl<T> DeviceHandle<T>
Expand All @@ -234,6 +248,48 @@ where
self.inner.metrics()
}

pub async fn update_listen_port(&self, port: u16) -> Result<(), Error> {
{
let inbound = self.inner.inbound.lock().unwrap();
if inbound.local_port() == port {
return Ok(());
}
}

let new_inbound = Inbound::bind(port).await?;
let mut inbound = self.inner.inbound.lock().unwrap();
if inbound.local_port() == port {
return Ok(());
}
*inbound = new_inbound;

let v4 = inbound.v4();
let v6 = inbound.v6();

let mut handles = self.inbound_handles.lock().unwrap();
handles.0.cancel();

for peer in &self.inner.cfg.lock().unwrap().peers {
let pk = peer.public_key;
if let Some(peer) = self.inner.peers.get_by_key(&pk) {
if let Some(endpoint) = peer.endpoint() {
peer.update_endpoint(inbound.endpoint_for(endpoint.dst()));
}
}
}

let token = self.token.child_token();
*handles = (
token.clone(),
vec![
tokio::spawn(loop_inbound(Arc::clone(&self.inner), v4, token.clone())),
tokio::spawn(loop_inbound(Arc::clone(&self.inner), v6, token)),
],
);

Ok(())
}

/// Returns the configuration of a peer by its public key.
pub fn peer_config(&self, public_key: &[u8; 32]) -> Option<PeerConfig> {
self.inner
Expand Down Expand Up @@ -311,30 +367,6 @@ where
}
}

async fn loop_tun_events<T>(inner: Arc<Inner<T>>, token: CancellationToken)
where
T: Tun + 'static,
{
debug!("starting tun events loop");
loop {
tokio::select! {
_ = token.cancelled() => {
debug!("stopping tun events loop");
return;
}
_ = tick_tun_events(Arc::clone(&inner)) => {}
}
}
}

#[inline]
async fn tick_tun_events<T>(_inner: Arc<Inner<T>>)
where
T: Tun + 'static,
{
tokio::time::sleep(Duration::from_secs(5)).await;
}

async fn loop_inbound<T>(inner: Arc<Inner<T>>, mut listener: Listener, token: CancellationToken)
where
T: Tun + 'static,
Expand Down
11 changes: 11 additions & 0 deletions src/device/peer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ where
self.inner.update_endpoint(endpoint);
}

#[inline]
pub fn endpoint(&self) -> Option<Endpoint> {
self.inner.endpoint()
}

#[inline]
pub fn metrics(&self) -> PeerMetrics {
self.inner.monitor.metrics()
Expand Down Expand Up @@ -258,6 +263,12 @@ where
let _ = guard.insert(endpoint);
}

/// Return the endpoint of the peer.
#[inline]
pub fn endpoint(&self) -> Option<Endpoint> {
self.endpoint.read().unwrap().clone()
}

/// Send outbound data to the peer.
/// This method is called by the outbound loop and handshake loop.
#[inline]
Expand Down
4 changes: 3 additions & 1 deletion src/uapi/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl Connection {
b"set=1\n" => {
let mut buf = vec![];
while self.reader.read_until(b'\n', &mut buf).await? > 1 {}
let s = unsafe { String::from_utf8_unchecked(buf).trim_end().to_string() };
let s = unsafe { String::from_utf8_unchecked(buf).trim_end().to_owned() };

Ok(Request::Set(parse_set_request(&s)?))
}
Expand All @@ -62,6 +62,7 @@ impl Connection {
}
}

#[allow(clippy::too_many_lines)]
fn parse_set_request(s: &str) -> Result<SetDevice, Error> {
debug!("UAPI: parsing set request: {:?}", s);

Expand Down Expand Up @@ -193,6 +194,7 @@ mod tests {
use super::*;

#[test]
#[allow(clippy::too_many_lines)]
fn test_parse_set_request() {
let rv = parse_set_request(
"private_key=e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a
Expand Down
2 changes: 2 additions & 0 deletions src/uapi/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
pub enum Error {
#[error("invalid protocol")]
InvalidProtocol,
#[error("invalid configuration: {0}")]
InvalidConfiguration(String),
#[error("IO error: {0}")]
IO(#[from] std::io::Error),
}
Loading

0 comments on commit ac69e1e

Please sign in to comment.