Skip to content

Commit

Permalink
make it easy to consume a DmaStreamReader without copying (#474)
Browse files Browse the repository at this point in the history
* soft fail when crossing buffer boundaries in `get_buffer_aligned`

The `get_buffer_aligned` API could be better. It relies on user code to
be aware of buffer boundaries locations. This makes user code more
complicated and brittle because the buffer size will generally be a
configurable knob, subject to change.

Fear not, however, for there is a better way!

Instead of failing when reading too much in `get_buffer_aligned,` read
whatever we can from the current buffer. This allows the following:
* If the read fits inside a buffer, the user can consume the result
  directly without any copy;
* If the read crosses buffer boundaries and the user can consume partial
  results, it can do so in a loop without any copies;
* If the read crosses buffer boundaries and the user needs a complete
  result, `get_buffer_aligned` may be used trivially in a loop that
  concatenates the results in a reusable user-allocated buffer.

The following commit will replace the `AsyncRead::poll_read`
implementation with a simple loop calling `get_buffer_aligned.`

* enforce `max_pos` stream config in `poll_get_buffer_aligned`

We need to make sure the stream can't return any bytes beyond the
`max_pos` offset. Right now, `poll_read` respects this while
`poll_get_buffer_aligned` doesn't. This commit makes sure we are
consistent across the two.

* use `poll_get_buffer_aligned` in `DmaStreamReader::poll_read`

Massive simplification; Oh Yeah!!
  • Loading branch information
HippoBaro authored Nov 30, 2021
1 parent 5612d71 commit 0a00239
Showing 1 changed file with 69 additions and 154 deletions.
223 changes: 69 additions & 154 deletions glommio/src/io/dma_file_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ use crate::{
io::{dma_file::align_down, read_result::ReadResult, DmaFile},
sys::DmaBuffer,
task,
ByteSliceExt,
ByteSliceMutExt,
GlommioError,
ResourceType,
};
use ahash::AHashMap;
use core::task::Waker;
Expand Down Expand Up @@ -258,29 +255,6 @@ impl DmaStreamReaderState {
(pos & self.buffer_size_mask) >> self.buffer_size_shift
}

fn copy_data(&mut self, pos: u64, result: &mut [u8]) -> usize {
let buffer_id = self.buffer_id(pos);
let in_buffer_offset = self.offset_of(pos);
if pos >= self.max_pos {
return 0;
}
let max_len = std::cmp::min(result.len(), (self.max_pos - pos) as usize);
let len: usize;

match self.buffermap.get(&buffer_id) {
None => {
panic!(
"Buffer not found. But we should only call this function after we verified \
that all buffers exist"
);
}
Some(buffer) => {
len = buffer.read_at(in_buffer_offset, &mut result[..max_len]);
}
}
len
}

// returns true if heir was no waker for this buffer_id
// otherwise, it replaces the existing one and returns false
fn add_waker(&mut self, buffer_id: u64, waker: Waker) -> bool {
Expand Down Expand Up @@ -519,25 +493,19 @@ impl DmaStreamReader {
}

/// Allows access to the buffer that holds the current position with no
/// extra copy
/// In order to use this API, one must guarantee that reading the specified
/// length may not cross into a different buffer. Users of this API are
/// expected to be aware of their buffer size (selectable in the
/// [`DmaStreamReaderBuilder`]) and act accordingly.
/// extra copy.
///
/// This function returns a [`ReadResult`]. It contains all the bytes read,
/// which can be less than the requested amount. Users are expected to call
/// this in a loop like they would [`AsyncRead::poll_read`].
///
/// The buffer is also not released until the returned [`ReadResult`] goes
/// The buffer is not released until the returned [`ReadResult`] goes
/// out of scope. So if you plan to keep this alive for a long time this
/// is probably the wrong API.
///
/// If EOF is hit while reading with this method, the number of bytes in the
/// returned buffer will be less than number requested.
///
/// Let's say you want to open a file and check if its header is sane: this
/// is a good API for that.
///
/// But if after such header there is an index that you want to keep in
/// memory, then you are probably better off with one of the methods
/// from [`AsyncReadExt`].
/// returned buffer will be less than number requested and the remaining
/// bytes will be 0.
///
/// # Examples
/// ```no_run
Expand All @@ -558,9 +526,7 @@ impl DmaStreamReader {
/// });
/// ```
///
/// [`DmaStreamReader`]: struct.DmaStreamReader.html
/// [`DmaStreamReaderBuilder`]: struct.DmaStreamReaderBuilder.html
/// [`AsyncReadExt`]: https://docs.rs/futures-lite/1.11.2/futures_lite/io/trait.AsyncReadExt.html
/// [`AsyncRead::poll_read`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html#tymethod.poll_read
/// [`ReadResult`]: struct.ReadResult.html
pub async fn get_buffer_aligned(&mut self, len: u64) -> Result<ReadResult> {
poll_fn(|cx| self.poll_get_buffer_aligned(cx, len)).await
Expand All @@ -569,60 +535,31 @@ impl DmaStreamReader {
/// A variant of [`get_buffer_aligned`] that can be called from a poll
/// function context.
///
/// Allows access to the buffer that holds the current position with no
/// extra copy
/// In order to use this API, one must guarantee that reading the specified
/// length may not cross into a different buffer. Users of this API are
/// expected to be aware of their buffer size (selectable in the
/// [`DmaStreamReaderBuilder`]) and act accordingly.
///
/// The buffer is also not released until the returned [`ReadResult`] goes
/// out of scope. So if you plan to keep this alive for a long time this
/// is probably the wrong API.
///
/// If EOF is hit while reading with this method, the number of bytes in the
/// returned buffer will be less than number requested.
///
/// Let's say you want to open a file and check if its header is sane: this
/// is a good API for that.
///
/// But if after such header there is an index that you want to keep in
/// memory, then you are probably better off with one of the methods
/// from [`AsyncReadExt`].
///
/// [`get_buffer_aligned`]: Self::get_buffer_aligned
/// [`DmaStreamReader`]: struct.DmaStreamReader.html
/// [`DmaStreamReaderBuilder`]: struct.DmaStreamReaderBuilder.html
/// [`AsyncReadExt`]: https://docs.rs/futures-lite/1.11.2/futures_lite/io/trait.AsyncReadExt.html
/// [`ReadResult`]: struct.ReadResult.html
pub fn poll_get_buffer_aligned(
&mut self,
cx: &mut Context<'_>,
len: u64,
mut len: u64,
) -> Poll<Result<ReadResult>> {
if len == 0 {
return Poll::Ready(Ok(ReadResult::empty_buffer()));
}

let (start_id, end_id, buffer_size) = {
let (start_id, buffer_len) = {
let state = self.state.borrow();
let start_id = state.buffer_id(self.current_pos);
let end_id = state.buffer_id(self.current_pos + len - 1);
(start_id, end_id, state.buffer_size)
let offset = state.offset_of(self.current_pos);

// enforce max_pos
if self.current_pos + len > state.max_pos {
len = state.max_pos - self.current_pos;
}

(start_id, (self.buffer_size - offset as u64).min(len))
};

if start_id != end_id {
return Poll::Ready(Err(GlommioError::<()>::WouldBlock(ResourceType::File(
format!(
"Reading {} bytes from position {} would cross a buffer boundary (Buffer size \
{})",
len, self.current_pos, buffer_size
),
))));
if len == 0 {
return Poll::Ready(Ok(ReadResult::empty_buffer()));
}

let x = ready!(self.poll_get_buffer(cx, len, start_id))?;
self.skip(len);
let x = ready!(self.poll_get_buffer(cx, buffer_len, start_id))?;
self.skip(x.len() as u64);
Poll::Ready(Ok(x))
}

Expand Down Expand Up @@ -666,68 +603,9 @@ impl AsyncRead for DmaStreamReader {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state = self.state.borrow_mut();
if let Some(err) = current_error!(state) {
return Poll::Ready(err);
}

let mut pos = self.current_pos;
if pos > state.max_pos {
return Poll::Ready(Ok(0));
}

let start = state.buffer_id(pos);
let end = state.buffer_id(pos + buf.len() as u64);

// special-casing the single buffer scenario helps small reads, as it allows
// us to do a single buffer lookup instead of N * 2;
if start == end {
match state.get_cached_buffer(&start).cloned() {
Some(buffer) => {
let max_len = std::cmp::min(buf.len(), (state.max_pos - pos) as usize);
let offset = state.offset_of(pos);
let bytes_copied = buffer.read_at(offset, &mut buf[..max_len]);
drop(state);
self.skip(bytes_copied as u64);
Poll::Ready(Ok(bytes_copied))
}
None => {
if state.add_waker(start, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
Poll::Pending
}
}
} else {
for id in start..=end {
match state.get_cached_buffer(&id) {
Some(buffer) => {
if (buffer.len() as u64) < self.buffer_size {
break;
}
}
None => {
if state.add_waker(id, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
return Poll::Pending;
}
}
}

let mut current_offset = 0;
while current_offset < buf.len() {
let bytes_copied = state.copy_data(pos, &mut buf[current_offset..]);
current_offset += bytes_copied;
pos += bytes_copied as u64;
if bytes_copied == 0 {
break;
}
}
drop(state);
self.skip(current_offset as u64);
Poll::Ready(Ok(current_offset))
}
let res = ready!(self.poll_get_buffer_aligned(cx, buf.len() as u64))?;
buf[..res.len()].copy_from_slice(&res);
Poll::Ready(Ok(res.len()))
}
}

Expand Down Expand Up @@ -1793,12 +1671,6 @@ mod test {
let buffer = reader.get_buffer_aligned(8).await.unwrap();
assert_eq!(buffer.len(), 8);
check_contents!(*buffer, 1004);

match reader.get_buffer_aligned(20).await {
Err(_) => {},
Ok(_) => panic!("Expected an error"),
}
assert_eq!(reader.current_pos(), 1012);
reader.skip((128 << 10) - 1012);
let eof_short_buffer = reader.get_buffer_aligned(4).await.unwrap();
assert_eq!(eof_short_buffer.len(), 2);
Expand All @@ -1807,6 +1679,49 @@ mod test {
reader.close().await.unwrap();
});

file_stream_read_test!(read_get_buffer_aligned_cross_boundaries, path, _k, file, _file_size: 2048, {
let mut reader = DmaStreamReaderBuilder::new(file)
.with_buffer_size(1024)
.build();

reader.skip(1022);

match reader.get_buffer_aligned(130).await {
Err(_) => panic!("Expected partial success"),
Ok(res) => {
assert_eq!(res.len(), 2);
check_contents!(*res, 1022);
},
}
assert_eq!(reader.current_pos(), 1024);

match reader.get_buffer_aligned(64).await {
Err(_) => panic!("Expected success"),
Ok(res) => {
assert_eq!(res.len(), 64);
check_contents!(*res, 1024);
},
}
assert_eq!(reader.current_pos(), 1088);

reader.skip(896);

// EOF
match reader.get_buffer_aligned(128).await {
Err(_) => panic!("Expected success"),
Ok(res) => {
assert_eq!(res.len(), 64);
check_contents!(*res, 1984);
},
}
assert_eq!(reader.current_pos(), 2048);

let eof = reader.get_buffer_aligned(64).await.unwrap();
assert_eq!(eof.len(), 0);

reader.close().await.unwrap();
});

file_stream_read_test!(read_get_buffer_aligned_zero_buffer, path, _k, file, _file_size: 131072, {
let mut reader = DmaStreamReaderBuilder::new(file)
.with_buffer_size(131072)
Expand Down

0 comments on commit 0a00239

Please sign in to comment.