Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: fix panic in read_line #2541

Merged
merged 2 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tokio/src/io/util/lines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ where
let me = self.project();

let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?;
debug_assert_eq!(*me.read, 0);

if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));
Expand Down
71 changes: 49 additions & 22 deletions tokio/src/io/util/read_line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::future::Future;
use std::io;
use std::mem;
use std::pin::Pin;
use std::str;
use std::task::{Context, Poll};

cfg_io_util! {
Expand All @@ -14,45 +13,72 @@ cfg_io_util! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadLine<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut String,
bytes: Vec<u8>,
/// This is the buffer we were provided. It will be replaced with an empty string
/// while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
/// The actual allocation of the string is moved into a vector instead.
buf: Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}

pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R>
pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadLine {
reader,
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
buf,
buf: mem::replace(string, String::new()).into_bytes(),
output: string,
read: 0,
}
}

fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_read: usize) {
let original_len = vector.len() - num_bytes_read;
vector.truncate(original_len);
*output = String::from_utf8(vector).expect("The original data must be valid utf-8.");
}

pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
reader: Pin<&mut R>,
cx: &mut Context<'_>,
buf: &mut String,
bytes: &mut Vec<u8>,
output: &mut String,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read));
if str::from_utf8(&bytes).is_err() {
Poll::Ready(ret.and_then(|_| {
Err(io::Error::new(
let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));

// At this point both buf and output are empty. The allocation is in utf8_res.

debug_assert!(buf.is_empty());
match (io_res, utf8_res) {
(Ok(num_bytes), Ok(string)) => {
debug_assert_eq!(*read, 0);
*output = string;
Poll::Ready(Ok(num_bytes))
}
(Err(io_err), Ok(string)) => {
*output = string;
Poll::Ready(Err(io_err))
}
(Ok(num_bytes), Err(utf8_err)) => {
debug_assert_eq!(*read, 0);
put_back_original_data(output, utf8_err.into_bytes(), num_bytes);

Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"stream did not contain valid UTF-8",
))
}))
} else {
debug_assert!(buf.is_empty());
debug_assert_eq!(*read, 0);
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
Poll::Ready(ret)
)))
}
(Err(io_err), Err(utf8_err)) => {
put_back_original_data(output, utf8_err.into_bytes(), *read);

Poll::Ready(Err(io_err))
}
}
}

Expand All @@ -62,11 +88,12 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
output,
buf,
bytes,
read,
} = &mut *self;
read_line_internal(Pin::new(reader), cx, buf, bytes, read)

read_line_internal(Pin::new(reader), cx, output, buf, read)
}
}

Expand Down
17 changes: 10 additions & 7 deletions tokio/src/io/util/read_until.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,30 @@ use std::task::{Context, Poll};

cfg_io_util! {
/// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method.
/// The delimeter is included in the resulting vector.
taiki-e marked this conversation as resolved.
Show resolved Hide resolved
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadUntil<'a, R: ?Sized> {
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}

pub(crate) fn read_until<'a, R>(
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
) -> ReadUntil<'a, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadUntil {
reader,
byte,
delimeter,
buf,
read: 0,
}
Expand All @@ -37,14 +40,14 @@ where
pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>(
mut reader: Pin<&mut R>,
cx: &mut Context<'_>,
byte: u8,
delimeter: u8,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
loop {
let (done, used) = {
let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
if let Some(i) = memchr::memchr(byte, available) {
if let Some(i) = memchr::memchr(delimeter, available) {
buf.extend_from_slice(&available[..=i]);
(true, i + 1)
} else {
Expand All @@ -66,11 +69,11 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
byte,
delimeter,
buf,
read,
} = &mut *self;
read_until_internal(Pin::new(reader), cx, *byte, buf, read)
read_until_internal(Pin::new(reader), cx, *delimeter, buf, read)
}
}

Expand Down
2 changes: 2 additions & 0 deletions tokio/src/io/util/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ where
let n = ready!(read_until_internal(
me.reader, cx, *me.delim, me.buf, me.read,
))?;
// read_until_internal resets me.read to zero once it finds the delimeter
debug_assert_eq!(*me.read, 0);

if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));
Expand Down
82 changes: 80 additions & 2 deletions tokio/tests/io_read_line.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};

use std::io::Cursor;

Expand All @@ -27,3 +28,80 @@ async fn read_line() {
assert_eq!(n, 0);
assert_eq!(buf, "");
}

#[tokio::test]
async fn read_line_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld\nFizzBuz")
.read(b"z\n1\n2")
.build();

let mut read = BufReader::new(mock);

let mut line = "We say ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "Hello World\n".len());
assert_eq!(line.as_str(), "We say Hello World\n");

line = "I solve ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "FizzBuzz\n".len());
assert_eq!(line.as_str(), "I solve FizzBuzz\n");

line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(line.as_str(), "1\n");

line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(line.as_str(), "2");
}

#[tokio::test]
async fn read_line_invalid_utf8() {
let mock = Builder::new().read(b"Hello Wor\xffld.\n").build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::InvalidData);
assert_eq!(err.to_string(), "stream did not contain valid UTF-8");
assert_eq!(line.as_str(), "Foo");
}

#[tokio::test]
async fn read_line_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "FooHello Wor");
}

#[tokio::test]
async fn read_line_fail_and_utf8_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"\xff\xff\xff")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "Foo");
}
55 changes: 53 additions & 2 deletions tokio/tests/io_read_until.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};

#[tokio::test]
async fn read_until() {
Expand All @@ -21,3 +22,53 @@ async fn read_until() {
assert_eq!(n, 0);
assert_eq!(buf, []);
}

#[tokio::test]
async fn read_until_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld#Fizz\xffBuz")
.read(b"z#1#2")
.build();

let mut read = BufReader::new(mock);

let mut chunk = b"We say ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Hello World#".len());
assert_eq!(chunk, b"We say Hello World#");

chunk = b"I solve ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Fizz\xffBuzz\n".len());
assert_eq!(chunk, b"I solve Fizz\xffBuzz#");

chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(chunk, b"1#");

chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(chunk, b"2");
}

#[tokio::test]
async fn read_until_fail() {
let mock = Builder::new()
.read(b"Hello \xffWor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut chunk = b"Foo".to_vec();
let err = read
.read_until(b'#', &mut chunk)
.await
.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(chunk, b"FooHello \xffWor");
}