Skip to content

Commit f3375ce

Browse files
committed
Improve type safety, extract identical code
Avoid fragility of tracking objects and their FDs separately.
1 parent f21d62a commit f3375ce

File tree

1 file changed

+71
-47
lines changed

1 file changed

+71
-47
lines changed

src/unix_term.rs

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use std::env;
22
use std::fmt::Display;
33
use std::fs;
4-
use std::io;
5-
use std::io::{BufRead, BufReader};
4+
use std::io::{self, BufRead, BufReader};
65
use std::mem;
7-
use std::os::unix::io::AsRawFd;
6+
use std::os::fd::{AsRawFd, RawFd};
87
use std::str;
98

109
#[cfg(not(target_os = "macos"))]
@@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*;
1817
pub(crate) const DEFAULT_WIDTH: u16 = 80;
1918

2019
#[inline]
21-
pub(crate) fn is_a_terminal(out: &Term) -> bool {
20+
pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool {
2221
unsafe { libc::isatty(out.as_raw_fd()) != 0 }
2322
}
2423

@@ -66,41 +65,76 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> {
6665
}
6766
}
6867

69-
pub(crate) fn read_secure() -> io::Result<String> {
70-
let mut f_tty;
71-
let fd = unsafe {
72-
if libc::isatty(libc::STDIN_FILENO) == 1 {
73-
f_tty = None;
74-
libc::STDIN_FILENO
68+
enum Input<T> {
69+
Stdin(io::Stdin),
70+
File(T),
71+
}
72+
73+
impl Input<fs::File> {
74+
fn new() -> io::Result<Self> {
75+
let stdin = io::stdin();
76+
if is_a_terminal(&stdin) {
77+
Ok(Input::Stdin(stdin))
7578
} else {
7679
let f = fs::OpenOptions::new()
7780
.read(true)
7881
.write(true)
7982
.open("/dev/tty")?;
80-
let fd = f.as_raw_fd();
81-
f_tty = Some(BufReader::new(f));
82-
fd
83+
Ok(Input::File(f))
8384
}
84-
};
85+
}
86+
}
87+
88+
impl Input<BufReader<fs::File>> {
89+
fn new() -> io::Result<Self> {
90+
Ok(match Input::<fs::File>::new()? {
91+
Input::Stdin(s) => Self::Stdin(s),
92+
Input::File(f) => Self::File(BufReader::new(f)),
93+
})
94+
}
95+
}
96+
97+
impl<T: BufRead> Input<T> {
98+
fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
99+
match self {
100+
Self::Stdin(s) => s.read_line(buf),
101+
Self::File(f) => f.read_line(buf),
102+
}
103+
}
104+
}
105+
106+
impl AsRawFd for Input<fs::File> {
107+
fn as_raw_fd(&self) -> RawFd {
108+
match self {
109+
Self::Stdin(s) => s.as_raw_fd(),
110+
Self::File(f) => f.as_raw_fd(),
111+
}
112+
}
113+
}
114+
115+
impl AsRawFd for Input<BufReader<fs::File>> {
116+
fn as_raw_fd(&self) -> RawFd {
117+
match self {
118+
Self::Stdin(s) => s.as_raw_fd(),
119+
Self::File(f) => f.get_ref().as_raw_fd(),
120+
}
121+
}
122+
}
123+
124+
pub(crate) fn read_secure() -> io::Result<String> {
125+
let mut input = Input::<BufReader<fs::File>>::new()?;
85126

86127
let mut termios = mem::MaybeUninit::uninit();
87-
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
128+
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
88129
let mut termios = unsafe { termios.assume_init() };
89130
let original = termios;
90131
termios.c_lflag &= !libc::ECHO;
91-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?;
132+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?;
92133
let mut rv = String::new();
93134

94-
let read_rv = if let Some(f) = &mut f_tty {
95-
f.read_line(&mut rv)
96-
} else {
97-
io::stdin().read_line(&mut rv)
98-
};
99-
100-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?;
135+
let read_rv = input.read_line(&mut rv);
101136

102-
// Ensure the fd is only closed after everything has been restored.
103-
drop(f_tty);
137+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?;
104138

105139
read_rv.map(|_| {
106140
let len = rv.trim_end_matches(&['\r', '\n'][..]).len();
@@ -109,7 +143,7 @@ pub(crate) fn read_secure() -> io::Result<String> {
109143
})
110144
}
111145

112-
fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
146+
fn poll_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
113147
let mut pollfd = libc::pollfd {
114148
fd,
115149
events: libc::POLLIN,
@@ -124,7 +158,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
124158
}
125159

126160
#[cfg(target_os = "macos")]
127-
fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
161+
fn select_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
128162
unsafe {
129163
let mut read_fd_set: libc::fd_set = mem::zeroed();
130164

@@ -156,7 +190,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
156190
}
157191
}
158192

159-
fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
193+
fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
160194
// There is a bug on macos that ttys cannot be polled, only select()
161195
// works. However given how problematic select is in general, we
162196
// normally want to use poll there too.
@@ -169,7 +203,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
169203
poll_fd(fd, timeout)
170204
}
171205

172-
fn read_single_char(fd: i32) -> io::Result<Option<char>> {
206+
fn read_single_char(fd: RawFd) -> io::Result<Option<char>> {
173207
// timeout of zero means that it will not block
174208
let is_ready = select_or_poll_term_fd(fd, 0)?;
175209

@@ -188,7 +222,7 @@ fn read_single_char(fd: i32) -> io::Result<Option<char>> {
188222
// Similar to libc::read. Read count bytes into slice buf from descriptor fd.
189223
// If successful, return the number of bytes read.
190224
// Will return an error if nothing was read, i.e when called at end of file.
191-
fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
225+
fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result<u8> {
192226
let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) };
193227
if read < 0 {
194228
Err(io::Error::last_os_error())
@@ -207,7 +241,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
207241
}
208242
}
209243

210-
fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
244+
fn read_single_key_impl(fd: RawFd) -> Result<Key, io::Error> {
211245
loop {
212246
match read_single_char(fd)? {
213247
Some('\x1b') => {
@@ -301,27 +335,17 @@ fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
301335
}
302336

303337
pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
304-
let tty_f;
305-
let fd = unsafe {
306-
if libc::isatty(libc::STDIN_FILENO) == 1 {
307-
libc::STDIN_FILENO
308-
} else {
309-
tty_f = fs::OpenOptions::new()
310-
.read(true)
311-
.write(true)
312-
.open("/dev/tty")?;
313-
tty_f.as_raw_fd()
314-
}
315-
};
338+
let input = Input::<fs::File>::new()?;
339+
316340
let mut termios = core::mem::MaybeUninit::uninit();
317-
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
341+
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
318342
let mut termios = unsafe { termios.assume_init() };
319343
let original = termios;
320344
unsafe { libc::cfmakeraw(&mut termios) };
321345
termios.c_oflag = original.c_oflag;
322-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?;
323-
let rv: io::Result<Key> = read_single_key_impl(fd);
324-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?;
346+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?;
347+
let rv: io::Result<Key> = read_single_key_impl(input.as_raw_fd());
348+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?;
325349

326350
// if the user hit ^C we want to signal SIGINT to outselves.
327351
if let Err(ref err) = rv {

0 commit comments

Comments
 (0)