Skip to content

Commit 8503407

Browse files
committed
bug: make CongestedStream a correct Stream implementation
1 parent f7405ee commit 8503407

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

datafusion/physical-plan/src/sorts/sort_preserving_merge.rs

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ mod tests {
381381
use std::fmt::Formatter;
382382
use std::pin::Pin;
383383
use std::sync::Mutex;
384-
use std::task::{Context, Poll};
384+
use std::task::{ready, Context, Poll, Waker};
385385
use std::time::Duration;
386386

387387
use super::*;
@@ -1285,13 +1285,45 @@ mod tests {
12851285
"#);
12861286
}
12871287

1288+
/// Shared congestion state that also remembers which tasks are waiting on it.
///
/// `congestion_cleared` holds `Some(wakers)` while the congestion is in effect,
/// where `wakers` are the tasks that were told `Poll::Pending`. It becomes `None`
/// once the congestion has been cleared.
#[derive(Debug)]
struct Congestion {
    congestion_cleared: Mutex<Option<Vec<Waker>>>,
}

impl Congestion {
    /// Creates a new instance in the congested state with no registered wakers.
    fn new() -> Self {
        Congestion {
            congestion_cleared: Mutex::new(Some(Vec::new())),
        }
    }

    /// Marks the congestion as cleared and wakes every task registered so far.
    ///
    /// Calling this again after the congestion has been cleared is a no-op.
    fn clear_congestion(&self) {
        // Take the waker list out under the lock; the guard is a temporary and is
        // released at the end of this statement, so the wakers below run without
        // the mutex held and may safely call back into `check_congested`.
        let wakers = self.congestion_cleared.lock().unwrap().take();
        for waker in wakers.into_iter().flatten() {
            waker.wake();
        }
    }

    /// Returns `Poll::Ready(())` once the congestion has been cleared.
    ///
    /// While still congested, registers the waker of `cx` so that the task is
    /// woken by `clear_congestion`, and returns `Poll::Pending`.
    fn check_congested(&self, cx: &mut Context<'_>) -> Poll<()> {
        let mut cleared = self.congestion_cleared.lock().unwrap();
        match cleared.as_mut() {
            None => Poll::Ready(()),
            Some(wakers) => {
                // A task may poll many times before the congestion clears; keep
                // one entry per task instead of growing the list on every poll.
                let current = cx.waker();
                if !wakers.iter().any(|known| known.will_wake(current)) {
                    wakers.push(current.clone());
                }
                Poll::Pending
            }
        }
    }
}
1319+
12881320
/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
12891321
/// partition is exhausted from the start, and if it is polled more than once, it panics.
12901322
#[derive(Debug, Clone)]
12911323
struct CongestedExec {
12921324
schema: Schema,
12931325
cache: PlanProperties,
1294-
congestion_cleared: Arc<Mutex<bool>>,
1326+
congestion: Arc<Congestion>,
12951327
}
12961328

12971329
impl CongestedExec {
@@ -1346,7 +1378,7 @@ mod tests {
13461378
Ok(Box::pin(CongestedStream {
13471379
schema: Arc::new(self.schema.clone()),
13481380
none_polled_once: false,
1349-
congestion_cleared: Arc::clone(&self.congestion_cleared),
1381+
congestion: Arc::clone(&self.congestion),
13501382
partition,
13511383
}))
13521384
}
@@ -1373,15 +1405,15 @@ mod tests {
13731405
pub struct CongestedStream {
13741406
schema: SchemaRef,
13751407
none_polled_once: bool,
1376-
congestion_cleared: Arc<Mutex<bool>>,
1408+
congestion: Arc<Congestion>,
13771409
partition: usize,
13781410
}
13791411

13801412
impl Stream for CongestedStream {
13811413
type Item = Result<RecordBatch>;
13821414
fn poll_next(
13831415
mut self: Pin<&mut Self>,
1384-
_cx: &mut Context<'_>,
1416+
cx: &mut Context<'_>,
13851417
) -> Poll<Option<Self::Item>> {
13861418
match self.partition {
13871419
0 => {
@@ -1393,16 +1425,11 @@ mod tests {
13931425
}
13941426
}
13951427
1 => {
1396-
let cleared = self.congestion_cleared.lock().unwrap();
1397-
if *cleared {
1398-
Poll::Ready(None)
1399-
} else {
1400-
Poll::Pending
1401-
}
1428+
ready!(self.congestion.check_congested(cx));
1429+
Poll::Ready(None)
14021430
}
14031431
2 => {
1404-
let mut cleared = self.congestion_cleared.lock().unwrap();
1405-
*cleared = true;
1432+
self.congestion.clear_congestion();
14061433
Poll::Ready(None)
14071434
}
14081435
_ => unreachable!(),
@@ -1423,7 +1450,7 @@ mod tests {
14231450
let source = CongestedExec {
14241451
schema: schema.clone(),
14251452
cache: CongestedExec::compute_properties(Arc::new(schema.clone())),
1426-
congestion_cleared: Arc::new(Mutex::new(false)),
1453+
congestion: Arc::new(Congestion::new()),
14271454
};
14281455
let spm = SortPreservingMergeExec::new(
14291456
[PhysicalSortExpr::new_default(Arc::new(Column::new(

0 commit comments

Comments
 (0)