From 7e111c13bb78ce80b3007aa325839a47790a3341 Mon Sep 17 00:00:00 2001 From: allada Date: Tue, 16 Nov 2021 09:54:29 -0800 Subject: [PATCH] Add buf_channel that will be used to help transport bytes around --- .../tests/bytestream_server_test.rs | 2 +- util/BUILD | 28 +- util/buf_channel.rs | 334 ++++++++++++++++++ util/error.rs | 27 ++ util/tests/buf_channel_test.rs | 231 ++++++++++++ 5 files changed, 620 insertions(+), 2 deletions(-) create mode 100644 util/buf_channel.rs create mode 100644 util/tests/buf_channel_test.rs diff --git a/cas/grpc_service/tests/bytestream_server_test.rs b/cas/grpc_service/tests/bytestream_server_test.rs index 1699b1d53..448835811 100644 --- a/cas/grpc_service/tests/bytestream_server_test.rs +++ b/cas/grpc_service/tests/bytestream_server_test.rs @@ -320,7 +320,7 @@ pub mod read_tests { let result = result.err_tip(|| "Expected result to be ready")?; let expected_err_str = concat!( "status: NotFound, message: \"Hash 0123456789abcdef000000000000000000000000000000000123456789abcdef ", - "not found : Error retrieving data from store : Sender disconnected : Error reading data from ", + "not found : Error retrieving data from store : --- : Sender disconnected : Error reading data from ", "underlying store\", details: [], metadata: MetadataMap { headers: {} }", ); assert_eq!( diff --git a/util/BUILD b/util/BUILD index 7b0cad578..349785b34 100644 --- a/util/BUILD +++ b/util/BUILD @@ -6,11 +6,11 @@ rust_library( name = "error", srcs = ["error.rs"], deps = [ + "//proto", "//third_party:hex", "//third_party:prost", "//third_party:tokio", "//third_party:tonic", - "//proto", ], visibility = ["//visibility:public"], ) @@ -77,6 +77,19 @@ rust_library( visibility = ["//visibility:public"], ) +rust_library( + name = "buf_channel", + srcs = ["buf_channel.rs"], + deps = [ + "//third_party:bytes", + "//third_party:futures", + "//third_party:tokio", + "//third_party:tokio_util", + ":error", + ], + visibility = ["//visibility:public"], +) + rust_library( name 
= "fastcdc", srcs = ["fastcdc.rs"], @@ -142,3 +155,16 @@ rust_test( ":retry", ], ) + +rust_test( + name = "buf_channel_test", + srcs = ["tests/buf_channel_test.rs"], + deps = [ + "//third_party:bytes", + "//third_party:futures", + "//third_party:pretty_assertions", + "//third_party:tokio", + ":buf_channel", + ":error", + ], +) diff --git a/util/buf_channel.rs b/util/buf_channel.rs new file mode 100644 index 000000000..94d03cd3a --- /dev/null +++ b/util/buf_channel.rs @@ -0,0 +1,334 @@ +// Copyright 2021 Nathan (Blaise) Bruer. All rights reserved. + +use std::pin::Pin; +use std::task::Poll; + +use bytes::{BufMut, Bytes, BytesMut}; +use futures::{task::Context, Future, Stream, StreamExt}; +use tokio::sync::{mpsc, oneshot}; +pub use tokio_util::io::StreamReader; + +use error::{make_err, Code, Error, ResultExt}; + +/// Create a channel pair that can be used to transport buffer objects around to +/// different components. This wrapper is used because the streams give some +/// utility like managing EOF in a more friendly way, ensure if no EOF is received +/// it will send an error to the receiver channel before shutting down and count +/// the number of bytes sent. +pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) { + // We allow up to 2 items in the buffer at any given time. There is no major + // reason behind this magic number other than thinking it will be nice to give + // a little time for another thread to wake up and consume data if another + // thread is pumping large amounts of data into the channel. + let (tx, rx) = mpsc::channel(2); + let (close_tx, close_rx) = oneshot::channel(); + ( + DropCloserWriteHalf { + tx: Some(tx), + bytes_written: 0, + close_rx, + }, + DropCloserReadHalf { + rx: rx, + partial: None, + close_tx: Some(close_tx), + close_after_size: u64::MAX, + }, + ) +} + +/// Writer half of the pair. 
+pub struct DropCloserWriteHalf { + tx: Option>>, + bytes_written: u64, + /// Receiver channel used to know the error (or success) value of the + /// receiver end's drop status (ie: if the receiver dropped unexpectedly). + close_rx: oneshot::Receiver>, +} + +impl DropCloserWriteHalf { + /// Sends data over the channel to the receiver. + pub async fn send(&mut self, buf: Bytes) -> Result<(), Error> { + let tx = self + .tx + .as_ref() + .ok_or_else(|| make_err!(Code::Internal, "Tried to send while stream is closed"))?; + let buf_len = buf.len() as u64; + assert!(buf_len != 0, "Cannot send EOF in send(). Instead use send_eof()"); + let result = tx + .send(Ok(buf)) + .await + .map_err(|_| make_err!(Code::Internal, "Failed to write to data, receiver disconnected")); + if result.is_err() { + // Close our channel to prevent drop() from spawning a task. + self.tx = None; + } + self.bytes_written += buf_len; + result + } + + /// Sends an EOF (End of File) message to the receiver which will gracefully let the + /// stream know it has no more data. This will close the stream. + pub async fn send_eof(&mut self) -> Result<(), Error> { + assert!(self.tx.is_some(), "Tried to send an EOF when pipe is broken"); + self.tx = None; + + // The final result will be provided in this oneshot channel. + Pin::new(&mut self.close_rx) + .await + .map_err(|_| make_err!(Code::Internal, "Receiver went away before receiving EOF"))? + } + + /// Forwards data from this writer to a reader. This is an efficient way to bind a writer + /// and reader together to just forward the data on. + pub async fn forward(&mut self, mut reader: S, forward_eof: bool) -> Result<(), Error> + where + S: Stream> + Send + Unpin, + { + loop { + match reader.next().await { + Some(maybe_chunk) => { + let chunk = maybe_chunk.err_tip(|| "Failed to forward message")?; + if chunk.len() == 0 { + // Don't send EOF here. We instead rely on None result to be EOF. 
+ continue; + } + self.send(chunk).await?; + } + None => { + if forward_eof { + self.send_eof().await?; + } + break; + } + } + } + Ok(()) + } + + /// Returns the number of bytes written so far. This does not mean the receiver received + /// all of the bytes written to the stream so far. + pub fn get_bytes_written(&self) -> u64 { + self.bytes_written + } + + /// Returns if the pipe was broken. This is good for determining if the reader broke the + /// pipe or the writer broke the pipe, since this will only return true if the pipe was + /// broken by the writer. + pub fn is_pipe_broken(&self) -> bool { + self.tx.is_none() + } +} + +impl Drop for DropCloserWriteHalf { + /// This will notify the reader of an error if we did not send an EOF. + fn drop(&mut self) { + if let Some(tx) = self.tx.take() { + // If we do not notify the receiver of the premature close of the stream (ie: without EOF) + // we could end up with the receiver thinking everything is good and saving this bad data. + tokio::spawn(async move { + let _ = tx + .send(Err( + make_err!(Code::Internal, "Writer was dropped before EOF was sent",), + )) + .await; // Nowhere to send failure to write here. + }); + } + } +} + +/// Reader half of the pair. +pub struct DropCloserReadHalf { + rx: mpsc::Receiver>, + /// Represents a partial chunk of data. This is used when we only wanted + /// to take a part of the chunk in the stream and leave the rest. + partial: Option>, + /// A channel used to notify the sender that we are closed (with error). + close_tx: Option>>, + /// Once this number of bytes is sent the stream will be considered closed. + /// This is a work around for cases when we never receive an EOF because the + /// reader's future is dropped because it got the exact amount of data and + /// will never poll more. 
This prevents the `drop()` handle from sending an + /// error to our writer that we dropped the stream before receiving an EOF + /// if we know the exact amount of data we will receive in this stream. + close_after_size: u64, +} + +impl DropCloserReadHalf { + /// Receive a chunk of data. + pub async fn recv(&mut self) -> Result { + let maybe_chunk = match self.partial.take() { + Some(result_bytes) => Some(result_bytes), + None => self.rx.recv().await, + }; + match maybe_chunk { + Some(Ok(chunk)) => { + let chunk_len = chunk.len() as u64; + assert!(chunk_len != 0, "Chunk should never be EOF, expected None in this case"); + assert!( + self.close_after_size >= chunk_len, + "Received too much data. This only happens when `close_after_size` is set." + ); + self.close_after_size -= chunk_len; + if self.close_after_size == 0 { + assert!(self.close_tx.is_some(), "Expected stream to not be closed"); + self.close_tx.take().unwrap().send(Ok(())).map_err(|_| { + make_err!(Code::Internal, "Failed to send closing ok message to write with size") + })?; + } + Ok(chunk) + } + + Some(Err(e)) => Err(e), + + // None is a safe EOF received. + None => { + // Notify our sender that we received the EOF with success. + if let Some(close_tx) = self.close_tx.take() { + close_tx + .send(Ok(())) + .map_err(|_| make_err!(Code::Internal, "Failed to send closing ok message to write"))?; + } + Ok(Bytes::new()) + } + } + } + + /// Sets the number of bytes before the stream will be considered closed. + pub fn set_close_after_size(&mut self, size: u64) { + self.close_after_size = size; + } + + /// Utility function that will collect all the data of the stream into a Bytes struct. + /// This method is optimized to reduce copies when possible. + pub async fn collect_all_with_size_hint(mut self, size_hint: usize) -> Result { + let (first_chunk, second_chunk) = { + // This is an optimization for when there's only one chunk and an EOF. + // This prevents us from any copies and we just shuttle the bytes. 
+ let first_chunk = self + .recv() + .await + .err_tip(|| "Failed to recv first chunk in collect_all_with_size_hint")?; + + if first_chunk.len() == 0 { + return Ok(first_chunk); + } + + let second_chunk = self + .recv() + .await + .err_tip(|| "Failed to recv second chunk in collect_all_with_size_hint")?; + + if second_chunk.len() == 0 { + return Ok(first_chunk); + } + (first_chunk, second_chunk) + }; + + let mut buf = BytesMut::with_capacity(size_hint); + buf.put(first_chunk); + buf.put(second_chunk); + + loop { + let chunk = self + .recv() + .await + .err_tip(|| "Failed to recv in collect_all_with_size_hint")?; + if chunk.len() == 0 { + break; // EOF. + } + buf.put(chunk); + } + Ok(buf.freeze()) + } + + /// Takes exactly `size` number of bytes from the stream and returns them. + /// This means the stream will keep polling until either an EOF is received or + /// `size` bytes are received and concat them all together then return them. + /// This method is optimized to reduce copies when possible. + pub async fn take(&mut self, size: usize) -> Result { + fn populate_partial_if_needed( + current_size: usize, + desired_size: usize, + chunk: &mut Bytes, + partial: &mut Option>, + ) { + if current_size + chunk.len() <= desired_size { + return; + } + assert!(partial.is_none(), "Partial should have been consumed during the recv()"); + let local_partial = chunk.split_off(desired_size - current_size); + *partial = if local_partial.len() == 0 { + None + } else { + Some(Ok(local_partial)) + }; + } + + let (first_chunk, second_chunk) = { + // This is an optimization for a relatively common case when the first chunk in the + // stream satisfies all the requirements to fill this `take()`. + // This will save us from needing to copy the data into a new buffer and instead we can + // just forward on the original Bytes object. If we need more than the first chunk + // we will then go the slow path and actually copy our data. 
+ let mut first_chunk = self.recv().await.err_tip(|| "During first buf_channel::take()")?; + populate_partial_if_needed(0, size, &mut first_chunk, &mut self.partial); + if first_chunk.len() == 0 || first_chunk.len() >= size { + assert!( + first_chunk.len() == 0 || first_chunk.len() == size, + "Length should be exactly size here" + ); + return Ok(first_chunk); + } + + let mut second_chunk = self.recv().await.err_tip(|| "During second buf_channel::take()")?; + if second_chunk.len() == 0 { + assert!( + first_chunk.len() <= size, + "Length should never be larger than size here" + ); + return Ok(first_chunk); + } + populate_partial_if_needed(first_chunk.len(), size, &mut second_chunk, &mut self.partial); + (first_chunk, second_chunk) + }; + let mut output = BytesMut::with_capacity(size); + output.put(first_chunk); + output.put(second_chunk); + + loop { + let mut chunk = self.recv().await.err_tip(|| "During buf_channel::take()")?; + if chunk.len() == 0 { + break; // EOF. + } + + populate_partial_if_needed(output.len(), size, &mut chunk, &mut self.partial); + + output.put(chunk); + + if output.len() >= size { + assert!(output.len() == size); // Length should never be larger than size here. + break; + } + } + Ok(output.freeze()) + } +} + +impl Stream for DropCloserReadHalf { + type Item = Result; + + // TODO(blaise.bruer) This is not very efficient as we are creating a new future on every + // poll() call. It might be better to use a waker. 
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Box::pin(self.recv()).as_mut().poll(cx).map(|result| match result { + Ok(bytes) => { + if bytes.len() == 0 { + return None; + } + Some(Ok(bytes)) + } + Err(e) => Some(Err(e.to_std_err())), + }) + } +} diff --git a/util/error.rs b/util/error.rs index 3ebfa5dae..92e5f46a5 100644 --- a/util/error.rs +++ b/util/error.rs @@ -45,9 +45,15 @@ impl Error { pub fn merge>(mut self, other: E) -> Self { let mut other: Error = other.into(); + // This will help with knowing which messages are tied to different errors. + self.messages.push("---".to_string()); self.messages.append(&mut other.messages); self } + + pub fn to_std_err(self) -> std::io::Error { + std::io::Error::new(self.code.into(), self.messages.join(" : ")) + } } impl std::error::Error for Error {} @@ -147,6 +153,7 @@ pub trait ResultExt { S: std::string::ToString, F: (std::ops::FnOnce(&Error) -> (Code, S)) + Sized; + #[inline] fn err_tip(self, tip_fn: F) -> Result where Self: Sized, @@ -165,6 +172,7 @@ pub trait ResultExt { } impl> ResultExt for Result { + #[inline] fn err_tip_with_code(self, tip_fn: F) -> Result where Self: Sized, @@ -188,6 +196,8 @@ impl> ResultExt for Result { let mut e: Error = e.into(); if let Err(other_err) = other { let mut other_err: Error = other_err.into(); + // This will help with knowing which messages are tied to different errors. 
+ e.messages.push("---".to_string()); e.messages.append(&mut other_err.messages); } return Err(e); @@ -197,6 +207,7 @@ impl> ResultExt for Result { } impl ResultExt for Option { + #[inline] fn err_tip_with_code(self, tip_fn: F) -> Result where Self: Sized, @@ -315,3 +326,19 @@ impl From for Code { } } } + +impl From for std::io::ErrorKind { + fn from(kind: Code) -> Self { + match kind { + Code::Aborted => std::io::ErrorKind::Interrupted, + Code::AlreadyExists => std::io::ErrorKind::AlreadyExists, + Code::DeadlineExceeded => std::io::ErrorKind::TimedOut, + Code::Internal => std::io::ErrorKind::Other, + Code::InvalidArgument => std::io::ErrorKind::InvalidInput, + Code::NotFound => std::io::ErrorKind::NotFound, + Code::PermissionDenied => std::io::ErrorKind::PermissionDenied, + Code::Unavailable => std::io::ErrorKind::ConnectionRefused, + _ => std::io::ErrorKind::Other, + } + } +} diff --git a/util/tests/buf_channel_test.rs b/util/tests/buf_channel_test.rs new file mode 100644 index 000000000..c1dfe3aa9 --- /dev/null +++ b/util/tests/buf_channel_test.rs @@ -0,0 +1,231 @@ +// Copyright 2021 Nathan (Blaise) Bruer. All rights reserved. + +use bytes::Bytes; +use tokio::try_join; + +use buf_channel::make_buf_channel_pair; +use error::{make_err, Code, Error, ResultExt}; + +#[cfg(test)] +mod buf_channel_tests { + use super::*; + use pretty_assertions::assert_eq; // Must be declared in every module. 
+ + const DATA1: &str = "foo"; + const DATA2: &str = "bar"; + + #[tokio::test] + async fn smoke_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + assert_eq!(rx.recv().await?, DATA1); + assert_eq!(rx.recv().await?, DATA2); + Ok(()) + } + + #[tokio::test] + async fn bytes_written_test() -> Result<(), Error> { + let (mut tx, _rx) = make_buf_channel_pair(); + tx.send(DATA1.into()).await?; + assert_eq!(tx.get_bytes_written(), DATA1.len() as u64); + tx.send(DATA2.into()).await?; + assert_eq!(tx.get_bytes_written(), (DATA1.len() + DATA2.len()) as u64); + Ok(()) + } + + #[tokio::test] + async fn sending_eof_sets_pipe_broken_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + assert_eq!(tx.is_pipe_broken(), false); + tx.send_eof().await?; + assert_eq!(tx.is_pipe_broken(), true); + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + assert_eq!(rx.recv().await?, Bytes::from(DATA1)); + assert_eq!(rx.recv().await?, Bytes::new()); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn rx_closes_before_eof_sends_err_to_tx_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + // Send one message. + tx.send(DATA1.into()).await?; + // Try to send EOF, but expect error because receiver will be dropped without taking it. + assert_eq!( + tx.send_eof().await, + Err(make_err!(Code::Internal, "Receiver went away before receiving EOF")) + ); + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + // Receive first message. + assert_eq!(rx.recv().await?, Bytes::from(DATA1)); + // Now drop rx without receiving EOF. 
+ Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn set_close_after_size_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + rx.set_close_after_size(DATA1.len() as u64); + assert_eq!(rx.recv().await?, Bytes::from(DATA1)); + // Now there's an EOF, but we are going to drop instead of taking it. + // We should not send an error to the tx. + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn collect_all_with_size_hint_test() -> Result<(), Error> { + let (mut tx, rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + assert_eq!( + rx.collect_all_with_size_hint(0).await?, + Bytes::from(format!("{}{}{}{}", DATA1, DATA2, DATA1, DATA2)) + ); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + /// Test to ensure data is optimized so that the exact same pointer is received + /// when calling `collect_all_with_size_hint` when a copy is not needed. + #[tokio::test] + async fn collect_all_with_size_hint_is_optimized_test() -> Result<(), Error> { + let (mut tx, rx) = make_buf_channel_pair(); + let sent_data = Bytes::from(DATA1); + let send_data_ptr = sent_data.as_ptr(); + let tx_fut = async move { + tx.send(sent_data).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + // Because data is 1 chunk and an EOF, we should not need to copy + // and should get the exact same pointer. 
+ let received_data = rx.collect_all_with_size_hint(0).await?; + assert_eq!(received_data.as_ptr(), send_data_ptr); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn take_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + let all_data = Bytes::from(format!("{}{}{}{}", DATA1, DATA2, DATA1, DATA2)); + assert_eq!(rx.take(1).await?, all_data.slice(0..1)); + assert_eq!(rx.take(3).await?, all_data.slice(1..4)); + assert_eq!(rx.take(4).await?, all_data.slice(4..8)); + // For the last chunk, request more data than remains and expect EOF to be hit. + assert_eq!(rx.take(100).await?, all_data.slice(8..12)); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + /// This test ensures that when we are taking just one message in the stream, + /// we don't need to concat the data together and instead return a view to + /// the original data instead of making a copy. 
+ #[tokio::test] + async fn take_optimized_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let first_chunk = Bytes::from(DATA1); + let first_chunk_ptr = first_chunk.as_ptr(); + let tx_fut = async move { + tx.send(first_chunk).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + assert_eq!(rx.take(1).await?.as_ptr(), first_chunk_ptr); + assert_eq!(rx.take(100).await?.as_ptr(), unsafe { first_chunk_ptr.add(1) }); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn simple_stream_test() -> Result<(), Error> { + use futures::StreamExt; + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send(DATA1.into()).await?; + tx.send(DATA2.into()).await?; + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + assert_eq!(rx.next().await.map(|v| v.err_tip(|| "")), Some(Ok(Bytes::from(DATA1)))); + assert_eq!(rx.next().await.map(|v| v.err_tip(|| "")), Some(Ok(Bytes::from(DATA2)))); + assert_eq!(rx.next().await.map(|v| v.err_tip(|| "")), Some(Ok(Bytes::from(DATA1)))); + assert_eq!(rx.next().await.map(|v| v.err_tip(|| "")), Some(Ok(Bytes::from(DATA2)))); + assert_eq!(rx.next().await.map(|v| v.err_tip(|| "")), None); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } + + #[tokio::test] + async fn rx_gets_error_if_tx_drops_test() -> Result<(), Error> { + let (mut tx, mut rx) = make_buf_channel_pair(); + let tx_fut = async move { + tx.send(DATA1.into()).await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + assert_eq!(rx.recv().await?, Bytes::from(DATA1)); + assert_eq!( + rx.recv().await, + Err(make_err!(Code::Internal, "Writer was dropped before EOF was sent")) + ); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + Ok(()) + } +}