@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
3030use super :: {
3131 DisplayAs , ExecutionPlanProperties , RecordBatchStream , SendableRecordBatchStream ,
3232} ;
33+ use crate :: coalesce:: { LimitedBatchCoalescer , PushBatchStatus } ;
3334use crate :: execution_plan:: { CardinalityEffect , EvaluationType , SchedulingType } ;
3435use crate :: hash_utils:: create_hashes;
3536use crate :: metrics:: { BaselineMetrics , SpillMetrics } ;
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
932933 spill_stream,
933934 1 , // Each receiver handles one input partition
934935 BaselineMetrics :: new ( & metrics, partition) ,
936+ context. session_config ( ) . batch_size ( ) ,
935937 ) ) as SendableRecordBatchStream
936938 } )
937939 . collect :: < Vec < _ > > ( ) ;
@@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec {
970972 spill_stream,
971973 num_input_partitions,
972974 BaselineMetrics :: new ( & metrics, partition) ,
975+ context. session_config ( ) . batch_size ( ) ,
973976 ) ) as SendableRecordBatchStream )
974977 }
975978 } )
@@ -1427,9 +1430,12 @@ struct PerPartitionStream {
14271430
14281431 /// Execution metrics
14291432 baseline_metrics : BaselineMetrics ,
1433+
1434+ coalescer : LimitedBatchCoalescer ,
14301435}
14311436
14321437impl PerPartitionStream {
1438+ #[ allow( clippy:: too_many_arguments) ]
14331439 fn new (
14341440 schema : SchemaRef ,
14351441 receiver : DistributionReceiver < MaybeBatch > ,
@@ -1438,16 +1444,18 @@ impl PerPartitionStream {
14381444 spill_stream : SendableRecordBatchStream ,
14391445 num_input_partitions : usize ,
14401446 baseline_metrics : BaselineMetrics ,
1447+ batch_size : usize ,
14411448 ) -> Self {
14421449 Self {
1443- schema,
1450+ schema : Arc :: clone ( & schema ) ,
14441451 receiver,
14451452 _drop_helper : drop_helper,
14461453 reservation,
14471454 spill_stream,
14481455 state : StreamState :: ReadingMemory ,
14491456 remaining_partitions : num_input_partitions,
14501457 baseline_metrics,
1458+ coalescer : LimitedBatchCoalescer :: new ( schema, batch_size, None ) ,
14511459 }
14521460 }
14531461
@@ -1540,7 +1548,49 @@ impl Stream for PerPartitionStream {
15401548 mut self : Pin < & mut Self > ,
15411549 cx : & mut Context < ' _ > ,
15421550 ) -> Poll < Option < Self :: Item > > {
1543- let poll = self . poll_next_inner ( cx) ;
1551+ let cloned_time = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
1552+ let mut completed = false ;
1553+
1554+ let poll;
1555+ loop {
1556+ if let Some ( batch) = self . coalescer . next_completed_batch ( ) {
1557+ poll = Poll :: Ready ( Some ( Ok ( batch) ) ) ;
1558+ break ;
1559+ }
1560+ if completed {
1561+ poll = Poll :: Ready ( None ) ;
1562+ break ;
1563+ }
1564+ let inner_poll = self . poll_next_inner ( cx) ;
1565+ let _timer = cloned_time. timer ( ) ;
1566+
1567+ match inner_poll {
1568+ Poll :: Pending => {
1569+ poll = Poll :: Pending ;
1570+ break ;
1571+ }
1572+ Poll :: Ready ( None ) => {
1573+ completed = true ;
1574+ self . coalescer . finish ( ) ?;
1575+ }
1576+ Poll :: Ready ( Some ( Ok ( batch) ) ) => {
1577+ match self . coalescer . push_batch ( batch) ? {
1578+ PushBatchStatus :: Continue => {
1579+ // Keep pushing more batches
1580+ }
1581+ PushBatchStatus :: LimitReached => {
1582+ // limit was reached, so stop early
1583+ completed = true ;
1584+ self . coalescer . finish ( ) ?;
1585+ }
1586+ }
1587+ }
1588+ Poll :: Ready ( Some ( err) ) => {
1589+ poll = Poll :: Ready ( Some ( err) ) ;
1590+ break ;
1591+ }
1592+ }
1593+ }
15441594 self . baseline_metrics . record_poll ( poll)
15451595 }
15461596}
@@ -1575,9 +1625,9 @@ mod tests {
15751625 use datafusion_common:: exec_err;
15761626 use datafusion_common:: test_util:: batches_to_sort_string;
15771627 use datafusion_common_runtime:: JoinSet ;
1628+ use datafusion_execution:: config:: SessionConfig ;
15781629 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
15791630 use insta:: assert_snapshot;
1580- use itertools:: Itertools ;
15811631
15821632 #[ tokio:: test]
15831633 async fn one_to_many_round_robin ( ) -> Result < ( ) > {
@@ -1671,7 +1721,10 @@ mod tests {
16711721 input_partitions : Vec < Vec < RecordBatch > > ,
16721722 partitioning : Partitioning ,
16731723 ) -> Result < Vec < Vec < RecordBatch > > > {
1674- let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
1724+ let task_ctx = Arc :: new (
1725+ TaskContext :: default ( )
1726+ . with_session_config ( SessionConfig :: new ( ) . with_batch_size ( 8 ) ) ,
1727+ ) ;
16751728 // create physical plan
16761729 let exec =
16771730 TestMemoryExec :: try_new_exec ( & input_partitions, Arc :: clone ( schema) , None ) ?;
@@ -1950,14 +2003,13 @@ mod tests {
19502003 } ) ;
19512004 let batches_with_drop = crate :: common:: collect ( output_stream1) . await . unwrap ( ) ;
19522005
1953- fn sort ( batch : Vec < RecordBatch > ) -> Vec < RecordBatch > {
1954- batch
1955- . into_iter ( )
1956- . sorted_by_key ( |b| format ! ( "{b:?}" ) )
1957- . collect ( )
1958- }
1959-
1960- assert_eq ! ( sort( batches_without_drop) , sort( batches_with_drop) ) ;
2006+ let items_vec_with_drop = str_batches_to_vec ( & batches_with_drop) ;
2007+ let items_set_with_drop: HashSet < & str > =
2008+ items_vec_with_drop. iter ( ) . copied ( ) . collect ( ) ;
2009+ assert_eq ! (
2010+ items_set_with_drop. symmetric_difference( & items_set) . count( ) ,
2011+ 0
2012+ ) ;
19612013 }
19622014
19632015 fn str_batches_to_vec ( batches : & [ RecordBatch ] ) -> Vec < & str > {
@@ -2396,6 +2448,7 @@ mod test {
23962448 use arrow:: compute:: SortOptions ;
23972449 use arrow:: datatypes:: { DataType , Field , Schema } ;
23982450 use datafusion_common:: assert_batches_eq;
2451+ use datafusion_execution:: config:: SessionConfig ;
23992452
24002453 use super :: * ;
24012454 use crate :: test:: TestMemoryExec ;
@@ -2507,8 +2560,10 @@ mod test {
25072560 let runtime = RuntimeEnvBuilder :: default ( )
25082561 . with_memory_limit ( 64 , 1.0 )
25092562 . build_arc ( ) ?;
2510-
2511- let task_ctx = TaskContext :: default ( ) . with_runtime ( runtime) ;
2563+ let session_config = SessionConfig :: new ( ) . with_batch_size ( 4 ) ;
2564+ let task_ctx = TaskContext :: default ( )
2565+ . with_runtime ( runtime)
2566+ . with_session_config ( session_config) ;
25122567 let task_ctx = Arc :: new ( task_ctx) ;
25132568
25142569 // Create physical plan with order preservation
0 commit comments