Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make it easy to consume a DmaStreamReader without copying #474

Merged
merged 3 commits into from
Nov 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 69 additions & 154 deletions glommio/src/io/dma_file_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ use crate::{
io::{dma_file::align_down, read_result::ReadResult, DmaFile},
sys::DmaBuffer,
task,
ByteSliceExt,
ByteSliceMutExt,
GlommioError,
ResourceType,
};
use ahash::AHashMap;
use core::task::Waker;
Expand Down Expand Up @@ -258,29 +255,6 @@ impl DmaStreamReaderState {
(pos & self.buffer_size_mask) >> self.buffer_size_shift
}

fn copy_data(&mut self, pos: u64, result: &mut [u8]) -> usize {
let buffer_id = self.buffer_id(pos);
let in_buffer_offset = self.offset_of(pos);
if pos >= self.max_pos {
return 0;
}
let max_len = std::cmp::min(result.len(), (self.max_pos - pos) as usize);
let len: usize;

match self.buffermap.get(&buffer_id) {
None => {
panic!(
"Buffer not found. But we should only call this function after we verified \
that all buffers exist"
);
}
Some(buffer) => {
len = buffer.read_at(in_buffer_offset, &mut result[..max_len]);
}
}
len
}

// returns true if heir was no waker for this buffer_id
// otherwise, it replaces the existing one and returns false
fn add_waker(&mut self, buffer_id: u64, waker: Waker) -> bool {
Expand Down Expand Up @@ -519,25 +493,19 @@ impl DmaStreamReader {
}

/// Allows access to the buffer that holds the current position with no
/// extra copy
/// In order to use this API, one must guarantee that reading the specified
/// length may not cross into a different buffer. Users of this API are
/// expected to be aware of their buffer size (selectable in the
/// [`DmaStreamReaderBuilder`]) and act accordingly.
/// extra copy.
///
/// This function returns a [`ReadResult`]. It contains all the bytes read,
/// which can be less than the requested amount. Users are expected to call
/// this in a loop like they would [`AsyncRead::poll_read`].
///
/// The buffer is also not released until the returned [`ReadResult`] goes
/// The buffer is not released until the returned [`ReadResult`] goes
/// out of scope. So if you plan to keep this alive for a long time this
/// is probably the wrong API.
///
/// If EOF is hit while reading with this method, the number of bytes in the
/// returned buffer will be less than number requested.
///
/// Let's say you want to open a file and check if its header is sane: this
/// is a good API for that.
///
/// But if after such header there is an index that you want to keep in
/// memory, then you are probably better off with one of the methods
/// from [`AsyncReadExt`].
/// returned buffer will be less than number requested and the remaining
/// bytes will be 0.
///
/// # Examples
/// ```no_run
Expand All @@ -558,9 +526,7 @@ impl DmaStreamReader {
/// });
/// ```
///
/// [`DmaStreamReader`]: struct.DmaStreamReader.html
/// [`DmaStreamReaderBuilder`]: struct.DmaStreamReaderBuilder.html
/// [`AsyncReadExt`]: https://docs.rs/futures-lite/1.11.2/futures_lite/io/trait.AsyncReadExt.html
/// [`AsyncRead::poll_read`]: https://docs.rs/futures-io/latest/futures_io/trait.AsyncRead.html#tymethod.poll_read
/// [`ReadResult`]: struct.ReadResult.html
pub async fn get_buffer_aligned(&mut self, len: u64) -> Result<ReadResult> {
poll_fn(|cx| self.poll_get_buffer_aligned(cx, len)).await
Expand All @@ -569,60 +535,31 @@ impl DmaStreamReader {
/// A variant of [`get_buffer_aligned`] that can be called from a poll
/// function context.
///
/// Allows access to the buffer that holds the current position with no
/// extra copy
/// In order to use this API, one must guarantee that reading the specified
/// length may not cross into a different buffer. Users of this API are
/// expected to be aware of their buffer size (selectable in the
/// [`DmaStreamReaderBuilder`]) and act accordingly.
///
/// The buffer is also not released until the returned [`ReadResult`] goes
/// out of scope. So if you plan to keep this alive for a long time this
/// is probably the wrong API.
///
/// If EOF is hit while reading with this method, the number of bytes in the
/// returned buffer will be less than number requested.
///
/// Let's say you want to open a file and check if its header is sane: this
/// is a good API for that.
///
/// But if after such header there is an index that you want to keep in
/// memory, then you are probably better off with one of the methods
/// from [`AsyncReadExt`].
///
/// [`get_buffer_aligned`]: Self::get_buffer_aligned
/// [`DmaStreamReader`]: struct.DmaStreamReader.html
/// [`DmaStreamReaderBuilder`]: struct.DmaStreamReaderBuilder.html
/// [`AsyncReadExt`]: https://docs.rs/futures-lite/1.11.2/futures_lite/io/trait.AsyncReadExt.html
/// [`ReadResult`]: struct.ReadResult.html
pub fn poll_get_buffer_aligned(
&mut self,
cx: &mut Context<'_>,
len: u64,
mut len: u64,
) -> Poll<Result<ReadResult>> {
if len == 0 {
return Poll::Ready(Ok(ReadResult::empty_buffer()));
}

let (start_id, end_id, buffer_size) = {
let (start_id, buffer_len) = {
let state = self.state.borrow();
let start_id = state.buffer_id(self.current_pos);
let end_id = state.buffer_id(self.current_pos + len - 1);
(start_id, end_id, state.buffer_size)
let offset = state.offset_of(self.current_pos);

// enforce max_pos
if self.current_pos + len > state.max_pos {
len = state.max_pos - self.current_pos;
}

(start_id, (self.buffer_size - offset as u64).min(len))
};

if start_id != end_id {
return Poll::Ready(Err(GlommioError::<()>::WouldBlock(ResourceType::File(
format!(
"Reading {} bytes from position {} would cross a buffer boundary (Buffer size \
{})",
len, self.current_pos, buffer_size
),
))));
if len == 0 {
return Poll::Ready(Ok(ReadResult::empty_buffer()));
}

let x = ready!(self.poll_get_buffer(cx, len, start_id))?;
self.skip(len);
let x = ready!(self.poll_get_buffer(cx, buffer_len, start_id))?;
self.skip(x.len() as u64);
Poll::Ready(Ok(x))
}

Expand Down Expand Up @@ -666,68 +603,9 @@ impl AsyncRead for DmaStreamReader {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state = self.state.borrow_mut();
if let Some(err) = current_error!(state) {
return Poll::Ready(err);
}

let mut pos = self.current_pos;
if pos > state.max_pos {
return Poll::Ready(Ok(0));
}

let start = state.buffer_id(pos);
let end = state.buffer_id(pos + buf.len() as u64);

// special-casing the single buffer scenario helps small reads, as it allows
// us to do a single buffer lookup instead of N * 2;
if start == end {
match state.get_cached_buffer(&start).cloned() {
Some(buffer) => {
let max_len = std::cmp::min(buf.len(), (state.max_pos - pos) as usize);
let offset = state.offset_of(pos);
let bytes_copied = buffer.read_at(offset, &mut buf[..max_len]);
drop(state);
self.skip(bytes_copied as u64);
Poll::Ready(Ok(bytes_copied))
}
None => {
if state.add_waker(start, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
Poll::Pending
}
}
} else {
for id in start..=end {
match state.get_cached_buffer(&id) {
Some(buffer) => {
if (buffer.len() as u64) < self.buffer_size {
break;
}
}
None => {
if state.add_waker(id, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
return Poll::Pending;
}
}
}

let mut current_offset = 0;
while current_offset < buf.len() {
let bytes_copied = state.copy_data(pos, &mut buf[current_offset..]);
current_offset += bytes_copied;
pos += bytes_copied as u64;
if bytes_copied == 0 {
break;
}
}
drop(state);
self.skip(current_offset as u64);
Poll::Ready(Ok(current_offset))
}
let res = ready!(self.poll_get_buffer_aligned(cx, buf.len() as u64))?;
buf[..res.len()].copy_from_slice(&res);
Poll::Ready(Ok(res.len()))
}
}

Expand Down Expand Up @@ -1793,12 +1671,6 @@ mod test {
let buffer = reader.get_buffer_aligned(8).await.unwrap();
assert_eq!(buffer.len(), 8);
check_contents!(*buffer, 1004);

match reader.get_buffer_aligned(20).await {
Err(_) => {},
Ok(_) => panic!("Expected an error"),
}
assert_eq!(reader.current_pos(), 1012);
reader.skip((128 << 10) - 1012);
let eof_short_buffer = reader.get_buffer_aligned(4).await.unwrap();
assert_eq!(eof_short_buffer.len(), 2);
Expand All @@ -1807,6 +1679,49 @@ mod test {
reader.close().await.unwrap();
});

file_stream_read_test!(read_get_buffer_aligned_cross_boundaries, path, _k, file, _file_size: 2048, {
let mut reader = DmaStreamReaderBuilder::new(file)
.with_buffer_size(1024)
.build();

reader.skip(1022);

match reader.get_buffer_aligned(130).await {
Err(_) => panic!("Expected partial success"),
Ok(res) => {
assert_eq!(res.len(), 2);
check_contents!(*res, 1022);
},
}
assert_eq!(reader.current_pos(), 1024);

match reader.get_buffer_aligned(64).await {
Err(_) => panic!("Expected success"),
Ok(res) => {
assert_eq!(res.len(), 64);
check_contents!(*res, 1024);
},
}
assert_eq!(reader.current_pos(), 1088);

reader.skip(896);

// EOF
match reader.get_buffer_aligned(128).await {
Err(_) => panic!("Expected success"),
Ok(res) => {
assert_eq!(res.len(), 64);
check_contents!(*res, 1984);
},
}
assert_eq!(reader.current_pos(), 2048);

let eof = reader.get_buffer_aligned(64).await.unwrap();
assert_eq!(eof.len(), 0);

reader.close().await.unwrap();
});

file_stream_read_test!(read_get_buffer_aligned_zero_buffer, path, _k, file, _file_size: 131072, {
let mut reader = DmaStreamReaderBuilder::new(file)
.with_buffer_size(131072)
Expand Down