diff --git a/Cargo.lock b/Cargo.lock index 96e183dd11..685ef6579b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,7 @@ dependencies = [ "prometheus 0.12.0", "prost", "proto 1.0.0-alpha01", + "rand 0.8.5", "serde", "serde_derive", "skiplist", diff --git a/analytic_engine/Cargo.toml b/analytic_engine/Cargo.toml index 491cfb0e7d..9cbc994d75 100644 --- a/analytic_engine/Cargo.toml +++ b/analytic_engine/Cargo.toml @@ -48,3 +48,4 @@ common_types = { workspace = true, features = ["test"] } common_util = { workspace = true, features = ["test"] } env_logger = { workspace = true } wal = { workspace = true, features = ["test"] } +rand = "0.8.5" diff --git a/analytic_engine/src/sst/parquet/async_reader.rs b/analytic_engine/src/sst/parquet/async_reader.rs index 2505fddc17..2163b98ef0 100644 --- a/analytic_engine/src/sst/parquet/async_reader.rs +++ b/analytic_engine/src/sst/parquet/async_reader.rs @@ -20,7 +20,7 @@ use common_types::{ use common_util::{runtime::Runtime, time::InstantExt}; use datafusion::datasource::file_format; use futures::{future::BoxFuture, FutureExt, Stream, StreamExt, TryFutureExt}; -use log::{debug, error, info}; +use log::{error, info}; use object_store::{ObjectMeta, ObjectStoreRef, Path}; use parquet::{ arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder, ProjectionMask}, @@ -457,8 +457,13 @@ impl Stream for RecordBatchReceiver { let cur_rx = self.rx_group.get_mut(cur_rx_idx).unwrap(); let poll_result = cur_rx.poll_recv(cx); - self.cur_rx_idx = (self.cur_rx_idx + 1) % self.rx_group.len(); - poll_result + match poll_result { + Poll::Ready(result) => { + self.cur_rx_idx = (self.cur_rx_idx + 1) % self.rx_group.len(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } } fn size_hint(&self) -> (usize, Option) { @@ -536,3 +541,113 @@ impl<'a> SstReader for ThreadedReader<'a> { }) as _) } } + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + task::{Context, Poll}, + time::Duration, + }; + + use futures::{Stream, StreamExt}; + use tokio::sync::mpsc::{self, Receiver, Sender}; + + struct MockReceivers { + rx_group: Vec>, + cur_rx_idx: usize, + } + + impl Stream for MockReceivers { + type Item = u32; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let cur_rx_idx = self.cur_rx_idx; + // `cur_rx_idx` is impossible to be out-of-range, because it is got by round + // robin. + let cur_rx = self.rx_group.get_mut(cur_rx_idx).unwrap(); + let poll_result = cur_rx.poll_recv(cx); + + match poll_result { + Poll::Ready(result) => { + self.cur_rx_idx = (self.cur_rx_idx + 1) % self.rx_group.len(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } + } + + struct MockRandomSenders { + tx_group: Vec>, + test_datas: Vec>, + } + + impl MockRandomSenders { + fn start_to_send(&mut self) { + while !self.tx_group.is_empty() { + let tx = self.tx_group.pop().unwrap(); + let test_data = self.test_datas.pop().unwrap(); + tokio::spawn(async move { + for datum in test_data { + let random_millis = rand::random::() % 30; + tokio::time::sleep(Duration::from_millis(random_millis)).await; + tx.send(datum).await.unwrap(); + } + }); + } + } + } + + fn gen_test_data(amount: usize) -> Vec { + (0..amount) + .into_iter() + .map(|_| rand::random::()) + .collect() + } + + // We mock a thread model same as the one in `ThreadedReader` to check its + // validity. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn test_simulated_threaded_reader() { + let test_data = gen_test_data(123); + let expected = test_data.clone(); + let channel_cap_per_sub_reader = 10; + let reader_num = 5; + let (tx_group, rx_group): (Vec<_>, Vec<_>) = (0..reader_num) + .into_iter() + .map(|_| mpsc::channel::(channel_cap_per_sub_reader)) + .unzip(); + + // Partition datas. + let chunk_len = reader_num; + let mut test_data_chunks = vec![Vec::new(); chunk_len]; + for (idx, datum) in test_data.into_iter().enumerate() { + let chunk_idx = idx % chunk_len; + test_data_chunks.get_mut(chunk_idx).unwrap().push(datum); + } + + // Start senders. + let mut mock_senders = MockRandomSenders { + tx_group, + test_datas: test_data_chunks, + }; + mock_senders.start_to_send(); + + // Poll receivers. + let mut actual = Vec::new(); + let mut mock_receivers = MockReceivers { + rx_group, + cur_rx_idx: 0, + }; + while let Some(datum) = mock_receivers.next().await { + actual.push(datum); + } + + assert_eq!(actual, expected); + } +}