Remove BytesSource and refactor write_source #2230

Open · wants to merge 2 commits into base: main
4 changes: 4 additions & 0 deletions quinn-proto/src/connection/send_buffer.rs
@@ -33,6 +33,10 @@ impl SendBuffer {

     /// Append application data to the end of the stream
     pub(super) fn write(&mut self, data: Bytes) {
+        if data.is_empty() {
+            return;
+        }
+
         self.unacked_len += data.len();
         self.offset += data.len() as u64;
         self.unacked_segments.push_back(data);
51 changes: 41 additions & 10 deletions quinn-proto/src/connection/streams/mod.rs
@@ -19,9 +19,8 @@ use recv::Recv;
 pub use recv::{Chunks, ReadError, ReadableError};
 
 mod send;
-pub(crate) use send::{ByteSlice, BytesArray};
-use send::{BytesSource, Send, SendState};
 pub use send::{FinishError, WriteError, Written};
+use send::{Send, SendState};

mod state;
#[allow(unreachable_pub)] // fuzzing only
@@ -221,7 +220,11 @@ impl<'a> SendStream<'a> {
     ///
     /// Returns the number of bytes successfully written.
     pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
-        Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
+        self.write_source(|limit, chunks| {
+            let prefix = &data[..limit.min(data.len())];
+            chunks.push(prefix.to_vec().into());
+            prefix.len()
+        })
     }

/// Send data on the given stream
@@ -231,10 +234,38 @@ impl<'a> SendStream<'a> {
     /// [`Written::chunks`] will not count this chunk as fully written. However
     /// the chunk will be advanced and contain only non-written data after the call.
     pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
-        self.write_source(&mut BytesArray::from_chunks(data))
+        self.write_source(|limit, chunks| {
+            let mut written = Written::default();
+            for chunk in data {
+                let prefix = chunk.split_to(chunk.len().min(limit - written.bytes));
+                written.bytes += prefix.len();
+                chunks.push(prefix);
+
+                if chunk.is_empty() {
+                    written.chunks += 1;
+                }
+
+                debug_assert!(written.bytes <= limit);
+                if written.bytes == limit {
+                    break;
+                }
+            }
+            written
+        })
    }

-    fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
+    /// Send data on the given stream
+    ///
+    /// The `source` callback is invoked with the number of bytes that can be written immediately,
+    /// as well as an initially empty `&mut Vec<Bytes>` to which it can push bytes to write. If the
+    /// callback pushes a total number of bytes less than or equal to the provided limit, it is
+    /// guaranteed they will all be written. If it provides more bytes than this, it is guaranteed
+    /// that a prefix of the provided cumulative bytes will be written equal in length to the
+    /// provided limit.
Collaborator:

What happens to the excess bytes? Would it be simpler to require (and even assert) that the limit is respected?

+    fn write_source<T>(
+        &mut self,
+        source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
Collaborator:

Straw alternative proposal: impl FnOnce(usize) -> I where I: IntoIter<Item=Bytes>. Let callers worry about their own return values via side-effects. This makes for a less leaky abstraction, saves a Vec, and relieves us from having to plumb a T back to the caller.

Collaborator (Author):

Re: I: IntoIter<Item=Bytes>:

I'm contemplating that, and I may experiment with it. You've mentioned the pros, the cons include:

  • More complex generics. Harder to read, adds compile time(?).
  • Probably trickier to use (I imagine I would change write_chunks to use std::iter::from_fn or something).
  • Application code able to panic in the middle of proto looping through the chunks, whereas currently it can only panic before that loop. Might be trickier to analyze.
  • May be in tension with your above suggestion to panic if the user tries to write too many bytes, because it makes it impossible to confirm up-front how many bytes the user is submitting (unless we have the user return an IntoIterator and then just collect it into a Vec, which seems like the worst of both worlds in many ways).

I'm kind of leaning against.

Collaborator (Author):

Re: "Let callers worry about their own return values via side-effects":

Here's the diff I would make to do that. It does not remove any lines of code from write_source/Send::write, it adds a net 2 lines of code, it forces me to think about the subtleties of closure variable capture, and it makes the callers do this awkward .map(|()| written) thing. Conversely, plumbing it through acts as a nice type-level proof of when the callback will or won't be called, making the code easier to reason about. Especially if in the future we expose it publicly, as in #2231.

Diff
diff --git a/quinn-proto/src/connection/streams/mod.rs b/quinn-proto/src/connection/streams/mod.rs
index 70f51e39..137cf370 100644
--- a/quinn-proto/src/connection/streams/mod.rs
+++ b/quinn-proto/src/connection/streams/mod.rs
@@ -220,11 +220,13 @@ impl<'a> SendStream<'a> {
     ///
     /// Returns the number of bytes successfully written.
     pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
+        let mut written = 0;
         self.write_source(|limit, chunks| {
             let prefix = &data[..limit.min(data.len())];
             chunks.push(prefix.to_vec().into());
-            prefix.len()
+            written = prefix.len();
         })
+        .map(|()| written)
     }
 
     /// Send data on the given stream
@@ -234,8 +236,8 @@ impl<'a> SendStream<'a> {
     /// [`Written::chunks`] will not count this chunk as fully written. However
     /// the chunk will be advanced and contain only non-written data after the call.
     pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
+        let mut written = Written::default();
         self.write_source(|limit, chunks| {
-            let mut written = Written::default();
             for chunk in data {
                 let prefix = chunk.split_to(chunk.len().min(limit - written.bytes));
                 written.bytes += prefix.len();
@@ -250,8 +252,8 @@ impl<'a> SendStream<'a> {
                     break;
                 }
             }
-            written
         })
+        .map(|()| written)
     }
 
     /// Send data on the given stream
@@ -262,10 +264,10 @@ impl<'a> SendStream<'a> {
     /// guaranteed they will all be written. If it provides more bytes than this, it is guaranteed
     /// that a prefix of the provided cumulative bytes will be written equal in length to the
     /// provided limit.
-    fn write_source<T>(
+    fn write_source(
         &mut self,
-        source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
-    ) -> Result<T, WriteError> {
+        source: impl FnOnce(usize, &mut Vec<Bytes>),
+    ) -> Result<(), WriteError> {
         if self.conn_state.is_closed() {
             trace!(%self.id, "write blocked; connection draining");
             return Err(WriteError::Blocked);
@@ -295,14 +297,14 @@ impl<'a> SendStream<'a> {
         }
 
         let was_pending = stream.is_pending();
-        let (written, source_output) = stream.write(source, limit)?;
+        let written = stream.write(source, limit)?;
         self.state.data_sent += written as u64;
         self.state.unacked_data += written as u64;
         trace!(stream = %self.id, "wrote {} bytes", written);
         if !was_pending {
             self.state.pending.push_pending(self.id, stream.priority);
         }
-        Ok(source_output)
+        Ok(())
     }
 
     /// Check if this stream was stopped, get the reason if it was
diff --git a/quinn-proto/src/connection/streams/send.rs b/quinn-proto/src/connection/streams/send.rs
index 52a9b714..2217b988 100644
--- a/quinn-proto/src/connection/streams/send.rs
+++ b/quinn-proto/src/connection/streams/send.rs
@@ -52,11 +52,11 @@ impl Send {
         }
     }
 
-    pub(super) fn write<T>(
+    pub(super) fn write(
         &mut self,
-        source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
+        source: impl FnOnce(usize, &mut Vec<Bytes>),
         limit: u64,
-    ) -> Result<(usize, T), WriteError> {
+    ) -> Result<usize, WriteError> {
         if !self.is_writable() {
             return Err(WriteError::ClosedStream);
         }
@@ -70,7 +70,7 @@ impl Send {
         let limit = limit.min(budget) as usize;
 
         debug_assert!(self.chunks.is_empty());
-        let source_output = source(limit, &mut self.chunks);
+        source(limit, &mut self.chunks);
 
         let mut written = 0;
 
@@ -85,7 +85,7 @@ impl Send {
             }
         }
 
-        Ok((written, source_output))
+        Ok(written)
     }
 
     /// Update stream state due to a reset sent by the local application

Member:

I think I prefer it with the diff. Avoiding the generic is worth the 2 extra lines to me and IMO distributes responsibility in a better way.

Collaborator:

it makes the callers do this awkward .map(|()| written) thing

I think this will look much less weird if you phrase it as ?; Ok(written).
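That phrasing, sketched against a stub `write_source` (the stub's signature and error type are invented for illustration):

```rust
// Stub standing in for SendStream::write_source after the proposed diff:
// it takes the callback, invokes it with a limit, and returns Result<()>.
fn write_source(source: impl FnOnce(usize)) -> Result<(), &'static str> {
    source(4);
    Ok(())
}

// Caller phrased as `?; Ok(written)` instead of `.map(|()| written)`.
fn write(data: &[u8]) -> Result<usize, &'static str> {
    let mut written = 0;
    write_source(|limit| written = limit.min(data.len()))?;
    Ok(written)
}

fn main() {
    assert_eq!(write(b"hi"), Ok(2));
}
```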

+    ) -> Result<T, WriteError> {
         if self.conn_state.is_closed() {
             trace!(%self.id, "write blocked; connection draining");
             return Err(WriteError::Blocked);
@@ -264,14 +295,14 @@ impl<'a> SendStream<'a> {
         }
 
         let was_pending = stream.is_pending();
-        let written = stream.write(source, limit)?;
-        self.state.data_sent += written.bytes as u64;
-        self.state.unacked_data += written.bytes as u64;
-        trace!(stream = %self.id, "wrote {} bytes", written.bytes);
+        let (written, source_output) = stream.write(source, limit)?;
+        self.state.data_sent += written as u64;
+        self.state.unacked_data += written as u64;
+        trace!(stream = %self.id, "wrote {} bytes", written);
         if !was_pending {
             self.state.pending.push_pending(self.id, stream.priority);
         }
-        Ok(written)
+        Ok(source_output)
     }

/// Check if this stream was stopped, get the reason if it was
230 changes: 19 additions & 211 deletions quinn-proto/src/connection/streams/send.rs
@@ -15,6 +15,8 @@ pub(super) struct Send {
     pub(super) connection_blocked: bool,
     /// The reason the peer wants us to stop, if `STOP_SENDING` was received
     pub(super) stop_reason: Option<VarInt>,
+    /// Reusable buf for usage within `write`--empty between calls to `self.write`
+    chunks: Vec<Bytes>,
}

impl Send {
@@ -27,6 +29,7 @@ impl Send {
             fin_pending: false,
             connection_blocked: false,
             stop_reason: None,
+            chunks: Vec::new(),
})
}

@@ -49,11 +52,11 @@ impl Send {
         }
     }
 
-    pub(super) fn write<S: BytesSource>(
+    pub(super) fn write<T>(
         &mut self,
-        source: &mut S,
+        source: impl FnOnce(usize, &mut Vec<Bytes>) -> T,
         limit: u64,
-    ) -> Result<Written, WriteError> {
+    ) -> Result<(usize, T), WriteError> {
         if !self.is_writable() {
             return Err(WriteError::ClosedStream);
         }
@@ -64,23 +67,25 @@ impl Send {
         if budget == 0 {
             return Err(WriteError::Blocked);
         }
-        let mut limit = limit.min(budget) as usize;
+        let limit = limit.min(budget) as usize;
 
-        let mut result = Written::default();
-        loop {
-            let (chunk, chunks_consumed) = source.pop_chunk(limit);
-            result.chunks += chunks_consumed;
-            result.bytes += chunk.len();
+        debug_assert!(self.chunks.is_empty());
+        let source_output = source(limit, &mut self.chunks);
 
-            if chunk.is_empty() {
+        let mut written = 0;
+
+        for mut chunk in self.chunks.drain(..) {
+            let prefix = chunk.split_to(chunk.len().min(limit - written));
+            written += prefix.len();
+            self.pending.write(prefix);
+
+            debug_assert!(written <= limit);
+            if written == limit {
                 break;
             }
-
-            limit -= chunk.len();
-            self.pending.write(chunk);
         }
 
-        Ok(result)
+        Ok((written, source_output))
     }

/// Update stream state due to a reset sent by the local application
@@ -143,106 +148,6 @@
}
}

/// A [`BytesSource`] implementation for `&'a mut [Bytes]`
///
/// The type allows to dequeue [`Bytes`] chunks from an array of chunks, up to
/// a configured limit.
pub(crate) struct BytesArray<'a> {
/// The wrapped slice of `Bytes`
chunks: &'a mut [Bytes],
/// The amount of chunks consumed from this source
consumed: usize,
}

impl<'a> BytesArray<'a> {
pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
Self {
chunks,
consumed: 0,
}
}
}

impl BytesSource for BytesArray<'_> {
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
// The loop exists to skip empty chunks while still marking them as
// consumed
let mut chunks_consumed = 0;

while self.consumed < self.chunks.len() {
let chunk = &mut self.chunks[self.consumed];

if chunk.len() <= limit {
let chunk = std::mem::take(chunk);
self.consumed += 1;
chunks_consumed += 1;
if chunk.is_empty() {
continue;
}
return (chunk, chunks_consumed);
} else if limit > 0 {
let chunk = chunk.split_to(limit);
return (chunk, chunks_consumed);
} else {
break;
}
}

(Bytes::new(), chunks_consumed)
}
}

/// A [`BytesSource`] implementation for `&[u8]`
///
/// The type allows to dequeue a single [`Bytes`] chunk, which will be lazily
/// created from a reference. This allows to defer the allocation until it is
/// known how much data needs to be copied.
pub(crate) struct ByteSlice<'a> {
/// The wrapped byte slice
data: &'a [u8],
}

impl<'a> ByteSlice<'a> {
pub(crate) fn from_slice(data: &'a [u8]) -> Self {
Self { data }
}
}

impl BytesSource for ByteSlice<'_> {
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
let limit = limit.min(self.data.len());
if limit == 0 {
return (Bytes::new(), 0);
}

let chunk = Bytes::from(self.data[..limit].to_owned());
self.data = &self.data[chunk.len()..];

let chunks_consumed = usize::from(self.data.is_empty());
(chunk, chunks_consumed)
}
}

/// A source of one or more buffers which can be converted into `Bytes` buffers on demand
///
/// The purpose of this data type is to defer conversion as long as possible,
/// so that no heap allocation is required in case no data is writable.
pub(super) trait BytesSource {
/// Returns the next chunk from the source of owned chunks.
///
/// This method will consume parts of the source.
/// Calling it will yield `Bytes` elements up to the configured `limit`.
///
/// The method returns a tuple:
/// - The first item is the yielded `Bytes` element. The element will be
/// empty if the limit is zero or no more data is available.
/// - The second item returns how many complete chunks inside the source
/// had been consumed. This can be less than 1, if a chunk inside the
/// source had been truncated in order to adhere to the limit. It can also
/// be more than 1, if zero-length chunks had been skipped.
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
}

/// Indicates how many bytes and chunks had been transferred in a write operation
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Written {
@@ -303,100 +208,3 @@ pub enum FinishError {
#[error("closed stream")]
ClosedStream,
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn bytes_array() {
let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
for limit in 0..full.len() {
let mut chunks = [
Bytes::from_static(b""),
Bytes::from_static(b"Hello "),
Bytes::from_static(b"Wo"),
Bytes::from_static(b""),
Bytes::from_static(b"r"),
Bytes::from_static(b"ld"),
Bytes::from_static(b""),
Bytes::from_static(b" 12345678"),
Bytes::from_static(b"9 ABCDE"),
Bytes::from_static(b"F"),
Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
];
let num_chunks = chunks.len();
let last_chunk_len = chunks[chunks.len() - 1].len();

let mut array = BytesArray::from_chunks(&mut chunks);

let mut buf = Vec::new();
let mut chunks_popped = 0;
let mut chunks_consumed = 0;
let mut remaining = limit;
loop {
let (chunk, consumed) = array.pop_chunk(remaining);
chunks_consumed += consumed;

if !chunk.is_empty() {
buf.extend_from_slice(&chunk);
remaining -= chunk.len();
chunks_popped += 1;
} else {
break;
}
}

assert_eq!(&buf[..], &full[..limit]);

if limit == full.len() {
// Full consumption of the last chunk
assert_eq!(chunks_consumed, num_chunks);
// Since there are empty chunks, we consume more than there are popped
assert_eq!(chunks_consumed, chunks_popped + 3);
} else if limit > full.len() - last_chunk_len {
// Partial consumption of the last chunk
assert_eq!(chunks_consumed, num_chunks - 1);
assert_eq!(chunks_consumed, chunks_popped + 2);
}
}
}

#[test]
fn byte_slice() {
let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
for limit in 0..full.len() {
let mut array = ByteSlice::from_slice(&full[..]);

let mut buf = Vec::new();
let mut chunks_popped = 0;
let mut chunks_consumed = 0;
let mut remaining = limit;
loop {
let (chunk, consumed) = array.pop_chunk(remaining);
chunks_consumed += consumed;

if !chunk.is_empty() {
buf.extend_from_slice(&chunk);
remaining -= chunk.len();
chunks_popped += 1;
} else {
break;
}
}

assert_eq!(&buf[..], &full[..limit]);
if limit != 0 {
assert_eq!(chunks_popped, 1);
} else {
assert_eq!(chunks_popped, 0);
}

if limit == full.len() {
assert_eq!(chunks_consumed, 1);
} else {
assert_eq!(chunks_consumed, 0);
}
}
}
}