Skip to content

Commit

Permalink
use poll_get_buffer_aligned in DmaStreamReader::poll_read
Browse files Browse the repository at this point in the history
Massive simplification; Oh Yeah!!
  • Loading branch information
HippoBaro committed Nov 29, 2021
1 parent a5b8f1f commit cd12a8c
Showing 1 changed file with 3 additions and 86 deletions.
89 changes: 3 additions & 86 deletions glommio/src/io/dma_file_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::{
io::{dma_file::align_down, read_result::ReadResult, DmaFile},
sys::DmaBuffer,
task,
ByteSliceExt,
ByteSliceMutExt,
};
use ahash::AHashMap;
Expand Down Expand Up @@ -256,29 +255,6 @@ impl DmaStreamReaderState {
(pos & self.buffer_size_mask) >> self.buffer_size_shift
}

fn copy_data(&mut self, pos: u64, result: &mut [u8]) -> usize {
let buffer_id = self.buffer_id(pos);
let in_buffer_offset = self.offset_of(pos);
if pos >= self.max_pos {
return 0;
}
let max_len = std::cmp::min(result.len(), (self.max_pos - pos) as usize);
let len: usize;

match self.buffermap.get(&buffer_id) {
None => {
panic!(
"Buffer not found. But we should only call this function after we verified \
that all buffers exist"
);
}
Some(buffer) => {
len = buffer.read_at(in_buffer_offset, &mut result[..max_len]);
}
}
len
}

// returns true if heir was no waker for this buffer_id
// otherwise, it replaces the existing one and returns false
fn add_waker(&mut self, buffer_id: u64, waker: Waker) -> bool {
Expand Down Expand Up @@ -637,68 +613,9 @@ impl AsyncRead for DmaStreamReader {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state = self.state.borrow_mut();
if let Some(err) = current_error!(state) {
return Poll::Ready(err);
}

let mut pos = self.current_pos;
if pos > state.max_pos {
return Poll::Ready(Ok(0));
}

let start = state.buffer_id(pos);
let end = state.buffer_id(pos + buf.len() as u64);

// special-casing the single buffer scenario helps small reads, as it allows
// us to do a single buffer lookup instead of N * 2;
if start == end {
match state.get_cached_buffer(&start).cloned() {
Some(buffer) => {
let max_len = std::cmp::min(buf.len(), (state.max_pos - pos) as usize);
let offset = state.offset_of(pos);
let bytes_copied = buffer.read_at(offset, &mut buf[..max_len]);
drop(state);
self.skip(bytes_copied as u64);
Poll::Ready(Ok(bytes_copied))
}
None => {
if state.add_waker(start, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
Poll::Pending
}
}
} else {
for id in start..=end {
match state.get_cached_buffer(&id) {
Some(buffer) => {
if (buffer.len() as u64) < self.buffer_size {
break;
}
}
None => {
if state.add_waker(id, cx.waker().clone()) {
state.fill_buffer(self.state.clone(), self.file.clone());
}
return Poll::Pending;
}
}
}

let mut current_offset = 0;
while current_offset < buf.len() {
let bytes_copied = state.copy_data(pos, &mut buf[current_offset..]);
current_offset += bytes_copied;
pos += bytes_copied as u64;
if bytes_copied == 0 {
break;
}
}
drop(state);
self.skip(current_offset as u64);
Poll::Ready(Ok(current_offset))
}
let (res, _) = ready!(self.poll_get_buffer_aligned(cx, buf.len() as u64))?;
buf[..res.len()].copy_from_slice(&res);
Poll::Ready(Ok(res.len()))
}
}

Expand Down

0 comments on commit cd12a8c

Please sign in to comment.