diff --git a/tokio/src/io/util/read_line.rs b/tokio/src/io/util/read_line.rs index c5ee597486f..549314089ed 100644 --- a/tokio/src/io/util/read_line.rs +++ b/tokio/src/io/util/read_line.rs @@ -5,7 +5,6 @@ use std::future::Future; use std::io; use std::mem; use std::pin::Pin; -use std::str; use std::task::{Context, Poll}; cfg_io_util! { @@ -26,7 +25,7 @@ where { ReadLine { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, + bytes: mem::replace(buf, String::new()).into_bytes(), buf, read: 0, } @@ -39,20 +38,20 @@ pub(super) fn read_line_internal( bytes: &mut Vec, read: &mut usize, ) -> Poll> { - let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read))?; + match String::from_utf8(mem::replace(bytes, Vec::new())) { + Ok(string) => { + debug_assert!(buf.is_empty()); + *buf = string; + Poll::Ready(Ok(ret)) + } + Err(e) => { + *bytes = e.into_bytes(); + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - debug_assert_eq!(*read, 0); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } } } @@ -66,7 +65,14 @@ impl Future for ReadLine<'_, R> { bytes, read, } = &mut *self; - read_line_internal(Pin::new(reader), cx, buf, bytes, read) + let ret = ready!(read_line_internal(Pin::new(reader), cx, buf, bytes, read)); + if ret.is_err() { + // If an error occurred, put the data back, but only if it is valid utf-8. + if let Ok(string) = String::from_utf8(mem::replace(bytes, Vec::new())) { + **buf = string; + } + } + Poll::Ready(ret) } } diff --git a/tokio/tests/io_read_line.rs b/tokio/tests/io_read_line.rs index 57ae37cef3e..995aa298721 100644 --- a/tokio/tests/io_read_line.rs +++ b/tokio/tests/io_read_line.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; use std::io::Cursor; @@ -27,3 +28,50 @@ async fn read_line() { assert_eq!(n, 0); assert_eq!(buf, ""); } + +#[tokio::test] +async fn read_line_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld\nFizzBuz") + .read(b"z\n1\n2") + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "We say ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "Hello World\n".len()); + assert_eq!(line.as_str(), "We say Hello World\n"); + + line = "I solve ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "FizzBuzz\n".len()); + assert_eq!(line.as_str(), "I solve FizzBuzz\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(line.as_str(), "1\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(line.as_str(), "2"); +} + +#[tokio::test] +async fn read_line_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "FooHello Wor"); +} diff --git a/tokio/tests/io_read_until.rs b/tokio/tests/io_read_until.rs index 4e0e0d10d34..99a8990a12b 100644 --- a/tokio/tests/io_read_until.rs +++ b/tokio/tests/io_read_until.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; #[tokio::test] async fn read_until() { @@ -21,3 +22,50 @@ async fn read_until() { assert_eq!(n, 0); assert_eq!(buf, []); } + +#[tokio::test] +async fn read_until_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld#FizzBuz") + .read(b"z#1#2") + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"We say ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Hello World#".len()); + assert_eq!(chunk, b"We say Hello World#"); + + chunk = b"I solve ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, "FizzBuzz\n".len()); + assert_eq!(chunk, b"I solve FizzBuzz#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(chunk, b"1#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(chunk, b"2"); +} + +#[tokio::test] +async fn read_until_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"Foo".to_vec(); + let err = read.read_until(b'#', &mut chunk).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(chunk, b"FooHello Wor"); +}