Skip to content

Commit

Permalink
Handle pseudo-terminal in the example client
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasp committed Jul 17, 2022
1 parent 2bb98be commit d92ab78
Show file tree
Hide file tree
Showing 8 changed files with 496 additions and 68 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ enclose = "1.1"
env_logger = "0.9"
futures = "0.3"
regex = "1.5"
termios = "0.3"
rustix = {version = "0.35", features = ["termios"]}
tokio = {version = "1", features = ["full"]}

[features]
Expand Down
305 changes: 242 additions & 63 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use bytes::BytesMut;
use enclose::enclose;
use futures::future::{FutureExt as _, FusedFuture as _};
use regex::Regex;
use rustix::termios;
use std::collections::HashSet;
use std::fs;
use std::{env, fs};
use std::future::Future;
use std::os::unix::io::AsRawFd as _;
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -46,6 +47,8 @@ fn run_main() -> Result<ExitCode> {
.arg(clap::Arg::new("command")
.takes_value(true)
.value_name("command"))
.arg(clap::Arg::new("want-tty").short('t')
.action(clap::ArgAction::SetTrue))
.get_matches();

let mut destination = Destination::default();
Expand All @@ -61,10 +64,11 @@ fn run_main() -> Result<ExitCode> {
.collect::<Result<Vec<_>>>()?;

let command = matches.get_one::<String>("command").cloned();
let want_tty = *matches.get_one::<bool>("want-tty").unwrap() || command.is_none();

let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all().build()?;
let exit_code = runtime.block_on(run_client(destination, keys, command))?;
let exit_code = runtime.block_on(run_client(destination, keys, command, want_tty))?;
runtime.shutdown_background();
Ok(exit_code)
}
Expand Down Expand Up @@ -109,7 +113,12 @@ fn read_key(path: &Path) -> Result<Key> {
Ok(Key { path: path.into(), data, decoded })
}

async fn run_client(destination: Destination, keys: Vec<Key>, command: Option<String>) -> Result<ExitCode> {
async fn run_client(
destination: Destination,
keys: Vec<Key>,
command: Option<String>,
want_tty: bool,
) -> Result<ExitCode> {
let host = destination.host
.context("please specify the host to connect to")?;
let username = destination.username
Expand Down Expand Up @@ -141,7 +150,7 @@ async fn run_client(destination: Destination, keys: Vec<Key>, command: Option<St

let config = makiko::ChannelConfig::default();
let (session, session_rx) = client.open_session(config).await?;
let exit_code = interact(session, session_rx, command).await?;
let exit_code = interact(session, session_rx, command, want_tty).await?;
client.disconnect(makiko::DisconnectError::by_app())?;
Result::<ExitCode>::Ok(exit_code)
}});
Expand Down Expand Up @@ -308,73 +317,33 @@ async fn authenticate(client: &makiko::Client, username: String, keys: Vec<Key>)
bail!("no authentication method succeeded")
}

async fn ask_yes_no(prompt: &str) -> Result<bool> {
let mut stdout = tokio::io::stdout();
stdout.write_all(format!("{} [y/N]: ", prompt).as_bytes()).await?;
stdout.flush().await?;

let mut stdin = tokio::io::stdin();
let mut yes = false;
loop {
let c = stdin.read_u8().await?;
if c == b'\r' || c == b'\n' {
break
} else if c.is_ascii_whitespace() {
continue
} else if c == b'y' || c == b'Y' {
yes = true;
} else {
yes = false;
}
}

Ok(yes)
}

async fn ask_for_password(prompt: &str) -> Result<String> {
let mut stdout = tokio::io::stdout();
stdout.write_all(format!("{}: ", prompt).as_bytes()).await?;
stdout.flush().await?;

let mut stdin = tokio::io::stdin();
let orig_termios = termios::Termios::from_fd(stdin.as_raw_fd())?;

let mut termios = orig_termios;
termios.c_lflag &= !termios::ECHO;
termios::tcsetattr(stdin.as_raw_fd(), termios::TCSADRAIN, &termios)?;

let mut password = Vec::new();
loop {
let c = stdin.read_u8().await?;
if password.is_empty() && c.is_ascii_whitespace() {
continue
} else if c == b'\r' || c == b'\n' {
break
} else {
password.push(c);
}
}
stdout.write_u8(b'\n').await?;

termios::tcsetattr(stdin.as_raw_fd(), termios::TCSADRAIN, &orig_termios)?;
Ok(std::str::from_utf8(&password)?.into())
}

async fn interact(
session: makiko::Session,
mut session_rx: makiko::SessionReceiver,
command: Option<String>,
want_tty: bool,
) -> Result<ExitCode> {
let mut pty_req = None;
let mut orig_tio = None;
if want_tty && termios::isatty(std::io::stdin()) {
pty_req = Some(get_pty_request()?);
orig_tio = Some(enter_raw_mode()?);
}

let recv_task = tokio::task::spawn(async move {
let mut stdout = tokio::io::stdout();
let mut stderr = tokio::io::stderr();

while let Some(event) = session_rx.recv().await? {
match event {
makiko::SessionEvent::StdoutData(data) =>
stdout.write_all(&data).await?,
makiko::SessionEvent::StderrData(data) =>
stderr.write_all(&data).await?,
makiko::SessionEvent::StdoutData(data) => {
stdout.write_all(&data).await?;
stdout.flush().await?;
},
makiko::SessionEvent::StderrData(data) => {
stderr.write_all(&data).await?;
stderr.flush().await?;
},
makiko::SessionEvent::ExitStatus(status) => {
log::info!("received exit status {}", status);
return Ok(ExitCode::from(status as u8))
Expand All @@ -393,6 +362,10 @@ async fn interact(
});

let send_task = tokio::task::spawn(enclose!{(session) async move {
if let Some(pty_req) = pty_req.as_ref() {
session.request_pty(&pty_req)?.want_reply().await?;
}

if let Some(command) = command {
session.exec(command.as_bytes())?.want_reply().await?;
} else {
Expand All @@ -401,9 +374,22 @@ async fn interact(

let mut stdin = tokio::io::stdin();
let mut stdin_buf = BytesMut::new();
while stdin.read_buf(&mut stdin_buf).await? != 0 {
session.send_stdin(stdin_buf.split().freeze()).await?;
let mut winch_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change())?;
loop {
tokio::select!{
res = stdin.read_buf(&mut stdin_buf) => {
if res? > 0 {
session.send_stdin(stdin_buf.split().freeze()).await?
} else {
break
}
},
Some(()) = winch_stream.recv() => {
session.window_change(&get_window_change()?)?;
},
}
}

session.send_eof().await?;
Result::<()>::Ok(())
}});
Expand All @@ -412,7 +398,12 @@ async fn interact(
let mut send_fut = AbortOnDrop(send_task).map(|res| res.expect("sending task panicked")).fuse();
loop {
tokio::select!{
recv_res = &mut recv_fut => return recv_res,
recv_res = &mut recv_fut => {
if let Some(tio) = orig_tio {
leave_raw_mode(tio);
}
return recv_res
},
send_res = &mut send_fut => send_res?,
};
}
Expand All @@ -434,3 +425,191 @@ impl<T> Drop for AbortOnDrop<T> {
}
}

async fn ask_yes_no(prompt: &str) -> Result<bool> {
let mut stdout = tokio::io::stdout();
stdout.write_all(format!("{} [y/N]: ", prompt).as_bytes()).await?;
stdout.flush().await?;

let mut stdin = tokio::io::stdin();
let mut yes = false;
loop {
let c = stdin.read_u8().await?;
if c == b'\r' || c == b'\n' {
break
} else if c.is_ascii_whitespace() {
continue
} else if c == b'y' || c == b'Y' {
yes = true;
} else {
yes = false;
}
}

Ok(yes)
}

async fn ask_for_password(prompt: &str) -> Result<String> {
let mut stdout = tokio::io::stdout();
stdout.write_all(format!("{}: ", prompt).as_bytes()).await?;
stdout.flush().await?;

let mut stdin = tokio::io::stdin();
let stdin_raw = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) };
let orig_tio = termios::tcgetattr(stdin_raw)?;

let mut tio = orig_tio;
tio.c_lflag &= !termios::ECHO;
termios::tcsetattr(stdin_raw, termios::OptionalActions::Drain, &tio)?;

let mut password = Vec::new();
loop {
let c = stdin.read_u8().await?;
if password.is_empty() && c.is_ascii_whitespace() {
continue
} else if c == b'\r' || c == b'\n' {
break
} else {
password.push(c);
}
}
stdout.write_u8(b'\n').await?;

termios::tcsetattr(stdin_raw, termios::OptionalActions::Drain, &orig_tio)?;
Ok(std::str::from_utf8(&password)?.into())
}

fn enter_raw_mode() -> Result<termios::Termios> {
// this code is shamelessly copied from OpenSSH

let stdin = tokio::io::stdin();
let stdin_raw = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) };

let orig_tio = termios::tcgetattr(stdin_raw)?;
let mut tio = orig_tio;

tio.c_iflag |= termios::IGNPAR;
tio.c_iflag &= !(termios::ISTRIP | termios::INLCR | termios::IGNCR | termios::ICRNL
| termios::IXON | termios::IXANY | termios::IXOFF | termios::IUCLC);
tio.c_lflag &= !(termios::ISIG | termios::ICANON | termios::ECHO | termios::ECHOE
| termios::ECHOK | termios::ECHONL | termios::IEXTEN);
tio.c_oflag &= !termios::OPOST;
tio.c_cc[termios::VMIN] = 1;
tio.c_cc[termios::VTIME] = 0;

log::debug!("entering terminal raw mode");
termios::tcsetattr(stdin_raw, termios::OptionalActions::Drain, &tio)?;
Ok(orig_tio)
}

fn leave_raw_mode(tio: termios::Termios) {
let stdin = tokio::io::stdin();
let stdin_raw = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) };
let _ = termios::tcsetattr(stdin_raw, termios::OptionalActions::Drain, &tio);
log::debug!("left terminal raw mode");
}

fn get_window_change() -> Result<makiko::WindowChange> {
let winsize = termios::tcgetwinsize(std::io::stdin())?;
Ok(makiko::WindowChange {
width: winsize.ws_col as u32,
height: winsize.ws_row as u32,
width_px: winsize.ws_xpixel as u32,
height_px: winsize.ws_ypixel as u32,
})
}

fn get_pty_request() -> Result<makiko::PtyRequest> {
// this code is shamelessly copied from OpenSSH

let mut req = makiko::PtyRequest::default();
req.term = env::var("TERM").unwrap_or(String::new());

let stdin = tokio::io::stdin();
let stdin_raw = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) };
let winsize = termios::tcgetwinsize(stdin_raw)?;
req.width = winsize.ws_col as u32;
req.height = winsize.ws_row as u32;
req.width_px = winsize.ws_xpixel as u32;
req.height_px = winsize.ws_ypixel as u32;

let tio = termios::tcgetattr(stdin_raw)?;

macro_rules! tty_char {
($name:ident, $op:ident) => {
let value = tio.c_cc[termios::$name];
let value = if value == 0 { 255 } else { value as u32 };
req.modes.add(makiko::codes::terminal_mode::$op, value);
};
($name:ident) => {
tty_char!($name, $name)
};
}

macro_rules! tty_mode {
($name:ident, $field:ident, $op:ident) => {
let value = (tio.$field & termios::$name) != 0;
let value = value as u32;
req.modes.add(makiko::codes::terminal_mode::$op, value);
};
($name:ident, $field:ident) => {
tty_mode!($name, $field, $name)
};
}

tty_char!(VINTR);
tty_char!(VQUIT);
tty_char!(VERASE);
tty_char!(VKILL);
tty_char!(VEOF);
tty_char!(VEOL);
tty_char!(VEOL2);
tty_char!(VSTART);
tty_char!(VSTOP);
tty_char!(VSUSP);
tty_char!(VREPRINT);
tty_char!(VWERASE);
tty_char!(VLNEXT);
tty_char!(VDISCARD);

tty_mode!(IGNPAR, c_iflag);
tty_mode!(PARMRK, c_iflag);
tty_mode!(INPCK, c_iflag);
tty_mode!(ISTRIP, c_iflag);
tty_mode!(INLCR, c_iflag);
tty_mode!(IGNCR, c_iflag);
tty_mode!(ICRNL, c_iflag);
tty_mode!(IUCLC, c_iflag);
tty_mode!(IXON, c_iflag);
tty_mode!(IXANY, c_iflag);
tty_mode!(IXOFF, c_iflag);
tty_mode!(IMAXBEL, c_iflag);
tty_mode!(IUTF8, c_iflag);

tty_mode!(ISIG, c_lflag);
tty_mode!(ICANON, c_lflag);
tty_mode!(XCASE, c_lflag);
tty_mode!(ECHO, c_lflag);
tty_mode!(ECHOE, c_lflag);
tty_mode!(ECHOK, c_lflag);
tty_mode!(ECHONL, c_lflag);
tty_mode!(NOFLSH, c_lflag);
tty_mode!(TOSTOP, c_lflag);
tty_mode!(IEXTEN, c_lflag);
tty_mode!(ECHOCTL, c_lflag);
tty_mode!(ECHOKE, c_lflag);
tty_mode!(PENDIN, c_lflag);

tty_mode!(OPOST, c_oflag);
tty_mode!(OLCUC, c_oflag);
tty_mode!(ONLCR, c_oflag);
tty_mode!(OCRNL, c_oflag);
tty_mode!(ONOCR, c_oflag);
tty_mode!(ONLRET, c_oflag);

tty_mode!(CS7, c_cflag);
tty_mode!(CS8, c_cflag);
tty_mode!(PARENB, c_cflag);
tty_mode!(PARODD, c_cflag);

Ok(req)
}
Loading

0 comments on commit d92ab78

Please sign in to comment.