Skip to content

Commit

Permalink
Implement local tunnels
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Jul 18, 2022
1 parent 472aeea commit 758f817
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 19 deletions.
29 changes: 29 additions & 0 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use parking_lot::Mutex;
use pin_project::pin_project;
use rand::rngs::OsRng;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
Expand All @@ -23,6 +24,7 @@ use super::client_event::ClientEvent;
use super::client_state::{self, ClientState};
use super::conn::{self, OpenChannel};
use super::session::{Session, SessionReceiver};
use super::tunnel::{Tunnel, TunnelReceiver};

/// Handle to an SSH connection.
///
Expand Down Expand Up @@ -195,6 +197,33 @@ impl Client {
Session::open(self, config).await
}

/// Opens a tunnel by asking the server to connect to a host ("local forwarding").
///
/// If the server accepts the request, it will try to connect to a host and port determined by
/// `connect_addr`. The host may be either an IP address or a domain name. You should also
/// specify the `originator_addr`, which should be the IP address and port of the machine from
/// where the connection request originates.
///
/// If the tunnel is opened successfully, you receive two objects:
///
/// - [`Tunnel`] is the handle for sending data to the server.
/// - [`TunnelReceiver`] receives the data from the server as
/// [`TunnelEvent`][super::TunnelEvent]s. You **must** receive these events in time, otherwise
/// the client will stall.
///
/// You can open many tunnels or sessions in parallel, the SSH protocol will multiplex them
/// over the underlying connection.
///
/// This method will wait until you are authenticated before doing anything.
pub async fn connect_tunnel(
&self,
config: ChannelConfig,
connect_addr: (String, u16),
originator_addr: (IpAddr, u16),
) -> Result<(Tunnel, TunnelReceiver)> {
Tunnel::connect(self, config, connect_addr, originator_addr).await
}

/// Opens a raw SSH channel (low level API).
///
/// Use this to directly open an SSH channel, as described in RFC 4254, section 5.
Expand Down
2 changes: 2 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use self::session::{
Session, SessionReceiver, SessionEvent, SessionReply, ExitSignal,
PtyRequest, PtyTerminalModes, WindowChange,
};
pub use self::tunnel::{Tunnel, TunnelReceiver, TunnelEvent};

#[macro_use] mod pump;
mod auth;
Expand All @@ -26,3 +27,4 @@ mod ext;
mod negotiate;
mod recv;
mod session;
mod tunnel;
1 change: 1 addition & 0 deletions src/client/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ pub struct ExitSignal {
/// by the server on the channel. You can ignore these events if you don't need them, but you
/// **must** receive them, otherwise the client will stall when the internal buffer of events fills
/// up.
#[derive(Debug)]
pub struct SessionReceiver {
channel_rx: ChannelReceiver,
}
Expand Down
131 changes: 131 additions & 0 deletions src/client/tunnel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use bytes::Bytes;
use futures_core::ready;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::codec::PacketEncode;
use crate::error::Result;
use super::channel::{Channel, ChannelReceiver, ChannelEvent, ChannelConfig, DATA_STANDARD};
use super::client::Client;

/// Handle to an SSH tunnel (TCP/IP forwarding channel).
///
/// TCP/IP forwarding channels (RFC 4253, section 7), commonly called "tunnels", allow you to
/// transmit ordinary TCP/IP sockets over SSH. There are two ways how to obtain a tunnel:
///
/// - You can ask the server to connect to an address using [`Client::connect_tunnel()`]. This is
/// sometimes called "local forwarding".
/// - You can ask the server to bind to an address and listen for incoming connections. This is
/// sometimes called "remote forwarding" and is not yet implemented.
#[derive(Clone)]
pub struct Tunnel {
channel: Channel,
}

impl Tunnel {
pub(super) async fn connect(
client: &Client,
config: ChannelConfig,
connect_addr: (String, u16),
originator_addr: (IpAddr, u16),
) -> Result<(Tunnel, TunnelReceiver)> {
let mut open_payload = PacketEncode::new();
open_payload.put_str(&connect_addr.0);
open_payload.put_u32(connect_addr.1 as u32);
open_payload.put_str(&originator_addr.0.to_string());
open_payload.put_u32(originator_addr.1 as u32);

let (channel, channel_rx, _) = client.open_channel(
"direct-tcpip".into(), config, open_payload.finish()).await?;
Ok((Tunnel { channel }, TunnelReceiver { channel_rx }))
}
}

impl Tunnel {
/// Send data to the tunnel.
///
/// This method returns after all bytes have been accepted by the flow control mechanism and
/// written to the internal send buffer, but before we send them to the socket (or other I/O
/// stream that backs this SSH connection).
pub async fn send_data(&self, data: Bytes) -> Result<()> {
self.channel.send_data(data, DATA_STANDARD).await
}

/// Signals that no more data will be sent to this channel.
///
/// This method returns after all bytes previously sent to this tunnel have been accepted by
/// the flow control mechanism, but before we write the message to the socket (or other I/O
/// stream that backs this SSH connection).
///
/// If the tunnel is closed before you call this method, or if it closes before this method
/// returns, we quietly ignore this error and return `Ok`.
pub async fn send_eof(&self) -> Result<()> {
self.channel.send_eof().await
}
}

/// Receiving half of a [`Tunnel`].
///
/// [`TunnelReceiver`] produces [`TunnelEvent`]s, which correspond to the data sent by the server
/// on the tunnel. You can ignore these events if you don't need them, but you **must** receive
/// them, otherwise the client will stall when the internal buffer of events fills up.
#[derive(Debug)]
pub struct TunnelReceiver {
channel_rx: ChannelReceiver,
}

/// An event returned from [`TunnelReceiver`].
///
/// These are events related to a particular SSH tunnel, they correspond to the data sent by the
/// server.
///
/// This enum is marked as `#[non_exhaustive]`, so that we can add new variants without breaking
/// backwards compatibility. It should always be safe to ignore any events that you don't intend to
/// handle.
#[derive(Debug)]
#[non_exhaustive]
pub enum TunnelEvent {
/// Data received from the tunnel.
///
/// You should handle this data as a byte stream, the boundaries between consecutive `Data`
/// events might be arbitrary.
Data(Bytes),

/// End of file received from the tunnel.
///
/// After this, we should not receive more data from the tunnel, but the tunnel is not yet
/// closed.
Eof,
}

impl TunnelReceiver {
/// Receive data from the tunnel.
///
/// Returns `None` if the tunnel was closed.
pub async fn recv(&mut self) -> Result<Option<TunnelEvent>> {
struct Recv<'a> { rx: &'a mut TunnelReceiver }
impl<'a> Future for Recv<'a> {
type Output = Result<Option<TunnelEvent>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
self.rx.poll_recv(cx)
}
}
Recv { rx: self }.await
}

/// Poll-friendly variant of [`.recv()`][Self::recv()].
pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<Result<Option<TunnelEvent>>> {
loop {
match ready!(self.channel_rx.poll_recv(cx)) {
Some(ChannelEvent::Data(data, DATA_STANDARD)) =>
return Poll::Ready(Ok(Some(TunnelEvent::Data(data)))),
Some(ChannelEvent::Eof) =>
return Poll::Ready(Ok(Some(TunnelEvent::Eof))),
Some(ChannelEvent::Data(_, _) | ChannelEvent::Request(_)) =>
continue,
None => return Poll::Ready(Ok(None)),
}
}
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use crate::client::{
Session, SessionReceiver, SessionEvent, SessionReply, ExitSignal,
PtyRequest, PtyTerminalModes, WindowChange,
};
pub use crate::client::{Tunnel, TunnelReceiver, TunnelEvent};
pub use crate::codec::{PacketEncode, PacketDecode};
pub use crate::error::{Result, Error, AlgoNegotiateError, DisconnectError, ChannelOpenError};

Expand Down
2 changes: 2 additions & 0 deletions tests/compat/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod nursery;
mod session_test;
mod smoke_test;
mod ssh_server;
mod tunnel_test;

#[path = "../keys/keys.rs"]
#[allow(dead_code)]
Expand Down Expand Up @@ -152,6 +153,7 @@ async fn run_all_tests(selector: TestSelector) -> Result<TestResult> {
smoke_test::collect(&mut suite);
auth_test::collect(&mut suite);
session_test::collect(&mut suite);
tunnel_test::collect(&mut suite);

let mut ctx = TestCtx { docker, selector, suite, result: TestResult::default() };
for server_name in server_names.into_iter() {
Expand Down
1 change: 1 addition & 0 deletions tests/compat/nursery.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![allow(dead_code)]
use futures::stream::{FuturesUnordered, Stream, StreamExt as _};
use std::future::Future;
use std::panic::resume_unwind;
Expand Down
13 changes: 3 additions & 10 deletions tests/compat/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use std::future::Future;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use crate::{TestSuite, TestCase, keys};
use crate::{TestSuite, TestCase};
use crate::nursery::Nursery;
use crate::smoke_test::authenticate_alice;

pub fn collect(suite: &mut TestSuite) {
suite.add(TestCase::new("session_cat",
Expand Down Expand Up @@ -325,15 +326,7 @@ async fn test_session_inner(
});

nursery.spawn(async move {
let res = client.auth_password("alice".into(), "alicealice".into()).await?;
if !matches!(res, makiko::AuthPasswordResult::Success) {
let res = client.auth_pubkey(
"alice".into(), keys::alice_ed25519(), &makiko::pubkey::SSH_ED25519).await?;
if !matches!(res, makiko::AuthPubkeyResult::Success) {
bail!("could not authenticate")
}
}

authenticate_alice(&client).await?;
let (session, session_rx) = client.open_session(makiko::ChannelConfig::default()).await?;
f(session, session_rx).await?;

Expand Down
22 changes: 13 additions & 9 deletions tests/compat/smoke_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,7 @@ async fn smoke_test(socket: TcpStream, config: makiko::ClientConfig) -> Result<(
});

nursery.spawn(enclose!{(nursery) async move {
let res = client.auth_password("alice".into(), "alicealice".into()).await?;
if !matches!(res, makiko::AuthPasswordResult::Success) {
let res = client.auth_pubkey(
"alice".into(), keys::alice_ed25519(), &makiko::pubkey::SSH_ED25519).await?;
if !matches!(res, makiko::AuthPubkeyResult::Success) {
bail!("could not authenticate")
}
}

authenticate_alice(&client).await?;
let (session, mut session_rx) = client.open_session(makiko::ChannelConfig::default()).await?;

let (stdout_tx, stdout_rx) = oneshot::channel();
Expand Down Expand Up @@ -154,3 +146,15 @@ async fn smoke_test(socket: TcpStream, config: makiko::ClientConfig) -> Result<(
drop(nursery);
nursery_stream.try_run().await
}

pub(super) async fn authenticate_alice(client: &makiko::Client) -> Result<()> {
let res = client.auth_password("alice".into(), "alicealice".into()).await?;
if !matches!(res, makiko::AuthPasswordResult::Success) {
let res = client.auth_pubkey(
"alice".into(), keys::alice_ed25519(), &makiko::pubkey::SSH_ED25519).await?;
if !matches!(res, makiko::AuthPubkeyResult::Success) {
bail!("could not authenticate")
}
}
Ok(())
}
Loading

0 comments on commit 758f817

Please sign in to comment.