@@ -381,7 +381,7 @@ mod tests {
381
381
use std:: fmt:: Formatter ;
382
382
use std:: pin:: Pin ;
383
383
use std:: sync:: Mutex ;
384
- use std:: task:: { Context , Poll } ;
384
+ use std:: task:: { ready , Context , Poll , Waker } ;
385
385
use std:: time:: Duration ;
386
386
387
387
use super :: * ;
@@ -1285,13 +1285,45 @@ mod tests {
1285
1285
"# ) ;
1286
1286
}
1287
1287
1288
+ #[ derive( Debug ) ]
1289
+ struct Congestion {
1290
+ congestion_cleared : Mutex < Option < Vec < Waker > > > ,
1291
+ }
1292
+
1293
+ impl Congestion {
1294
+ fn new ( ) -> Self {
1295
+ Congestion {
1296
+ congestion_cleared : Mutex :: new ( Some ( vec ! [ ] ) ) ,
1297
+ }
1298
+ }
1299
+
1300
+ fn clear_congestion ( & self ) {
1301
+ let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1302
+ if let Some ( wakers) = & mut * cleared {
1303
+ wakers. iter ( ) . for_each ( |w| w. wake_by_ref ( ) ) ;
1304
+ * cleared = None ;
1305
+ }
1306
+ }
1307
+
1308
+ fn check_congested ( & self , cx : & mut Context < ' _ > ) -> Poll < ( ) > {
1309
+ let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1310
+ match & mut * cleared {
1311
+ None => Poll :: Ready ( ( ) ) ,
1312
+ Some ( wakers) => {
1313
+ wakers. push ( cx. waker ( ) . clone ( ) ) ;
1314
+ Poll :: Pending
1315
+ }
1316
+ }
1317
+ }
1318
+ }
1319
+
1288
1320
/// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st
1289
1321
/// partition is exhausted from the start, and if it is polled more than one, it panics.
1290
1322
#[ derive( Debug , Clone ) ]
1291
1323
struct CongestedExec {
1292
1324
schema : Schema ,
1293
1325
cache : PlanProperties ,
1294
- congestion_cleared : Arc < Mutex < bool > > ,
1326
+ congestion : Arc < Congestion > ,
1295
1327
}
1296
1328
1297
1329
impl CongestedExec {
@@ -1346,7 +1378,7 @@ mod tests {
1346
1378
Ok ( Box :: pin ( CongestedStream {
1347
1379
schema : Arc :: new ( self . schema . clone ( ) ) ,
1348
1380
none_polled_once : false ,
1349
- congestion_cleared : Arc :: clone ( & self . congestion_cleared ) ,
1381
+ congestion : Arc :: clone ( & self . congestion ) ,
1350
1382
partition,
1351
1383
} ) )
1352
1384
}
@@ -1373,15 +1405,15 @@ mod tests {
1373
1405
pub struct CongestedStream {
1374
1406
schema : SchemaRef ,
1375
1407
none_polled_once : bool ,
1376
- congestion_cleared : Arc < Mutex < bool > > ,
1408
+ congestion : Arc < Congestion > ,
1377
1409
partition : usize ,
1378
1410
}
1379
1411
1380
1412
impl Stream for CongestedStream {
1381
1413
type Item = Result < RecordBatch > ;
1382
1414
fn poll_next (
1383
1415
mut self : Pin < & mut Self > ,
1384
- _cx : & mut Context < ' _ > ,
1416
+ cx : & mut Context < ' _ > ,
1385
1417
) -> Poll < Option < Self :: Item > > {
1386
1418
match self . partition {
1387
1419
0 => {
@@ -1393,16 +1425,11 @@ mod tests {
1393
1425
}
1394
1426
}
1395
1427
1 => {
1396
- let cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1397
- if * cleared {
1398
- Poll :: Ready ( None )
1399
- } else {
1400
- Poll :: Pending
1401
- }
1428
+ ready ! ( self . congestion. check_congested( cx) ) ;
1429
+ Poll :: Ready ( None )
1402
1430
}
1403
1431
2 => {
1404
- let mut cleared = self . congestion_cleared . lock ( ) . unwrap ( ) ;
1405
- * cleared = true ;
1432
+ self . congestion . clear_congestion ( ) ;
1406
1433
Poll :: Ready ( None )
1407
1434
}
1408
1435
_ => unreachable ! ( ) ,
@@ -1423,7 +1450,7 @@ mod tests {
1423
1450
let source = CongestedExec {
1424
1451
schema : schema. clone ( ) ,
1425
1452
cache : CongestedExec :: compute_properties ( Arc :: new ( schema. clone ( ) ) ) ,
1426
- congestion_cleared : Arc :: new ( Mutex :: new ( false ) ) ,
1453
+ congestion : Arc :: new ( Congestion :: new ( ) ) ,
1427
1454
} ;
1428
1455
let spm = SortPreservingMergeExec :: new (
1429
1456
[ PhysicalSortExpr :: new_default ( Arc :: new ( Column :: new (
0 commit comments