Skip to content

Commit

Permalink
Implement -R tunnels in the example client
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Aug 6, 2022
1 parent 6008e51 commit 1f33690
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 25 deletions.
116 changes: 91 additions & 25 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use anyhow::{Result, Context as _, bail};
use bytes::BytesMut;
use enclose::enclose;
use futures::ready;
use futures::future::{FutureExt as _, FusedFuture as _};
use futures::future::{FutureExt as _, FusedFuture as _, Fuse};
use futures::stream::{StreamExt as _, TryStreamExt as _, FuturesUnordered};
use guard::guard;
use regex::Regex;
use rustix::termios;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::{env, fs};
use std::future::Future;
use std::os::unix::io::AsRawFd as _;
Expand Down Expand Up @@ -101,7 +102,6 @@ struct Opts {
command: Option<String>,
want_tty: bool,
local_tunnels: Vec<TunnelSpec>,
#[allow(dead_code)]
remote_tunnels: Vec<TunnelSpec>,
}

Expand Down Expand Up @@ -188,23 +188,28 @@ async fn run_client(opts: Opts) -> Result<ExitCode> {
.context("could not open TCP connection to the server")?;
log::info!("successfully connected");

let (client, mut client_rx, client_fut) = makiko::Client::open(socket, config)?;
let remote_tunnel_addrs = opts.remote_tunnels.into_iter()
.map(|spec| {
let bind_addr = (spec.bind_host.unwrap_or("".into()), spec.bind_port);
let connect_addr = (spec.connect_host, spec.connect_port);
(bind_addr, connect_addr)
})
.collect::<HashMap<_, _>>();

let (client, client_rx, client_fut) = makiko::Client::open(socket, config)?;
let client_task = TaskHandle(tokio::task::spawn(client_fut));

let event_task = TaskHandle(tokio::task::spawn(enclose!{(client) async move {
while let Some(event) = client_rx.recv().await? {
if let makiko::ClientEvent::ServerPubkey(pubkey, accept_tx) = event {
verify_pubkey(&client, pubkey, accept_tx).await?;
}
}
Result::<()>::Ok(())
}}));
let event_task = TaskHandle(tokio::task::spawn(
run_events(client.clone(), client_rx, remote_tunnel_addrs.clone())
));

let interact_task = TaskHandle(tokio::task::spawn(enclose!{(client) async move {
authenticate(&client, username, opts.keys).await
.context("could not authenticate")?;
log::info!("successfully authenticated");

bind_remote_tunnels(&client, &remote_tunnel_addrs).await?;

let session_task = TaskHandle(tokio::task::spawn(enclose!{(client) async move {
run_session(client, opts.command, opts.want_tty).await
}}));
Expand Down Expand Up @@ -245,8 +250,40 @@ async fn run_client(opts: Opts) -> Result<ExitCode> {
}
}

async fn run_events(
client: makiko::Client,
mut client_rx: makiko::ClientReceiver,
remote_tunnel_addrs: HashMap<(String, u16), (String, u16)>,
) -> Result<()> {
let mut pubkey_task = Fuse::terminated();
let mut tunnel_tasks = FuturesUnordered::new();
loop {
tokio::select!{
event = client_rx.recv() => match event? {
Some(makiko::ClientEvent::ServerPubkey(pubkey, accept_tx)) => {
pubkey_task = TaskHandle(tokio::task::spawn(
verify_pubkey(client.clone(), pubkey, accept_tx)
)).fuse();
},
Some(makiko::ClientEvent::Tunnel(accept)) => {
let connect_addr = remote_tunnel_addrs.get(&accept.connected_addr);
guard!{let Some(connect_addr) = connect_addr else { continue }};
tunnel_tasks.push(TaskHandle(tokio::task::spawn(
run_remote_tunnel(accept, connect_addr.clone())
)));
},
Some(_) => continue,
None => break,
},
res = &mut pubkey_task => res?,
Some(res) = tunnel_tasks.next() => res?,
};
}
Ok(())
}

async fn verify_pubkey(
client: &makiko::Client,
client: makiko::Client,
pubkey: makiko::Pubkey,
accept_tx: makiko::AcceptPubkey,
) -> Result<()> {
Expand All @@ -265,6 +302,35 @@ async fn verify_pubkey(
Ok(())
}

async fn run_remote_tunnel(accept: makiko::AcceptTunnel, connect_addr: (String, u16)) -> Result<()> {
match tokio::net::TcpStream::connect(&connect_addr).await {
Ok(socket) => {
let config = makiko::ChannelConfig::default();
let (tunnel, tunnel_rx) = accept.accept(config).await?;
run_tunnel_socket(tunnel, tunnel_rx, socket).await
},
Err(err) => {
log::warn!("Could not open tunnel to {:?}: {}", connect_addr, err);
accept.reject(makiko::ChannelOpenError {
reason_code: makiko::codes::open::CONNECT_FAILED,
description: format!("Connect attempt failed: {}", err),
description_lang: "".into(),
});
Ok(())
},
}
}

async fn bind_remote_tunnels(
client: &makiko::Client,
remote_tunnel_addrs: &HashMap<(String, u16), (String, u16)>,
) -> Result<()> {
for bind_addr in remote_tunnel_addrs.keys() {
client.bind_tunnel(bind_addr.clone())?.wait().await?;
}
Ok(())
}

async fn authenticate(client: &makiko::Client, username: String, keys: Vec<Key>) -> Result<()> {
struct AuthCtx<'c> {
client: &'c makiko::Client,
Expand Down Expand Up @@ -488,29 +554,29 @@ async fn run_local_tunnel(client: makiko::Client, spec: TunnelSpec) -> Result<()
tokio::select!{
res = listener.accept() => {
let (socket, peer_addr) = res?;

let config = makiko::ChannelConfig::default();
let connect_addr = (spec.connect_host.clone(), spec.connect_port);
let originator_addr = (peer_addr.ip().to_string(), peer_addr.port());
let task = TaskHandle(tokio::task::spawn(
run_local_tunnel_socket(client.clone(), connect_addr, socket, originator_addr)
));
let (tunnel, tunnel_rx) = client.connect_tunnel(
config, connect_addr, originator_addr).await?;

let task = TaskHandle(tokio::task::spawn(run_tunnel_socket(tunnel, tunnel_rx, socket)));
socket_tasks.push(task);
},
Some(res) = socket_tasks.next() => res?,
}
}
}

async fn run_local_tunnel_socket(
client: makiko::Client,
connect_addr: (String, u16),
async fn run_tunnel_socket(
tunnel: makiko::Tunnel,
mut tunnel_rx: makiko::TunnelReceiver,
socket: tokio::net::TcpStream,
originator_addr: (String, u16),
) -> Result<()> {
let config = makiko::ChannelConfig::default();
let (tunnel, mut tunnel_rx) = client.connect_tunnel(config, connect_addr, originator_addr).await?;
let (mut socket_read, mut socket_write) = socket.into_split();

let socket_to_client = TaskHandle(tokio::task::spawn(async move {
let socket_to_tunnel = TaskHandle(tokio::task::spawn(async move {
let mut buffer = BytesMut::new();
while socket_read.read_buf(&mut buffer).await? != 0 {
tunnel.send_data(buffer.split().freeze()).await?;
Expand All @@ -519,7 +585,7 @@ async fn run_local_tunnel_socket(
Result::<_>::Ok(())
}));

let client_to_socket = TaskHandle(tokio::task::spawn(async move {
let tunnel_to_socket = TaskHandle(tokio::task::spawn(async move {
while let Some(event) = tunnel_rx.recv().await? {
match event {
makiko::TunnelEvent::Data(mut data) =>
Expand All @@ -532,7 +598,7 @@ async fn run_local_tunnel_socket(
Result::<_>::Ok(())
}));

tokio::try_join!(socket_to_client, client_to_socket)?;
tokio::try_join!(socket_to_tunnel, tunnel_to_socket)?;
Ok(())
}

Expand Down
4 changes: 4 additions & 0 deletions src/client/client_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,12 @@ impl AcceptChannel {
#[derive(Debug)]
pub struct AcceptTunnel {
accept: AcceptChannel,

/// The address on the SSH server that the remote peer has connected to.
///
/// This should be equal to the address that you have passed to [`Client::bind_tunnel()`].
pub connected_addr: (String, u16),

/// The address of the remote peer.
pub originator_addr: (String, u16),
}
Expand Down

0 comments on commit 1f33690

Please sign in to comment.