2121use std:: any:: Any ;
2222use std:: pin:: Pin ;
2323use std:: sync:: Arc ;
24- use std:: task:: { Context , Poll } ;
24+ use std:: task:: { ready , Context , Poll } ;
2525
2626use super :: metrics:: { BaselineMetrics , ExecutionPlanMetricsSet , MetricsSet } ;
2727use super :: { DisplayAs , ExecutionPlanProperties , PlanProperties , Statistics } ;
@@ -146,10 +146,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
146146 ) -> Result < SendableRecordBatchStream > {
147147 Ok ( Box :: pin ( CoalesceBatchesStream {
148148 input : self . input . execute ( partition, context) ?,
149- schema : self . input . schema ( ) ,
150- target_batch_size : self . target_batch_size ,
151- buffer : Vec :: new ( ) ,
152- buffered_rows : 0 ,
149+ coalescer : BatchCoalescer :: new ( self . input . schema ( ) , self . target_batch_size ) ,
153150 is_closed : false ,
154151 baseline_metrics : BaselineMetrics :: new ( & self . metrics , partition) ,
155152 } ) )
@@ -167,14 +164,8 @@ impl ExecutionPlan for CoalesceBatchesExec {
167164struct CoalesceBatchesStream {
168165 /// The input plan
169166 input : SendableRecordBatchStream ,
170- /// The input schema
171- schema : SchemaRef ,
172- /// Minimum number of rows for coalesces batches
173- target_batch_size : usize ,
174- /// Buffered batches
175- buffer : Vec < RecordBatch > ,
176- /// Buffered row count
177- buffered_rows : usize ,
167+ /// Buffer for combining batches
168+ coalescer : BatchCoalescer ,
178169 /// Whether the stream has finished returning all of its data or not
179170 is_closed : bool ,
180171 /// Execution metrics
@@ -213,66 +204,35 @@ impl CoalesceBatchesStream {
213204 let input_batch = self . input . poll_next_unpin ( cx) ;
214205 // records time on drop
215206 let _timer = cloned_time. timer ( ) ;
216- match input_batch {
217- Poll :: Ready ( x) => match x {
218- Some ( Ok ( batch) ) => {
219- if batch. num_rows ( ) >= self . target_batch_size
220- && self . buffer . is_empty ( )
221- {
222- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
223- } else if batch. num_rows ( ) == 0 {
224- // discard empty batches
225- } else {
226- // add to the buffered batches
227- self . buffered_rows += batch. num_rows ( ) ;
228- self . buffer . push ( batch) ;
229- // check to see if we have enough batches yet
230- if self . buffered_rows >= self . target_batch_size {
231- // combine the batches and return
232- let batch = concat_batches (
233- & self . schema ,
234- & self . buffer ,
235- self . buffered_rows ,
236- ) ?;
237- // reset buffer state
238- self . buffer . clear ( ) ;
239- self . buffered_rows = 0 ;
240- // return batch
241- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
242- }
243- }
244- }
245- None => {
246- self . is_closed = true ;
247- // we have reached the end of the input stream but there could still
248- // be buffered batches
249- if self . buffer . is_empty ( ) {
250- return Poll :: Ready ( None ) ;
251- } else {
252- // combine the batches and return
253- let batch = concat_batches (
254- & self . schema ,
255- & self . buffer ,
256- self . buffered_rows ,
257- ) ?;
258- // reset buffer state
259- self . buffer . clear ( ) ;
260- self . buffered_rows = 0 ;
261- // return batch
262- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
263- }
207+ match ready ! ( input_batch) {
208+ Some ( result) => {
209+ let Ok ( input_batch) = result else {
210+ return Poll :: Ready ( Some ( result) ) ; // pass back error
211+ } ;
212+ // Buffer the batch and either get more input if not enough
213+ // rows yet or output
214+ match self . coalescer . push_batch ( input_batch) {
215+ Ok ( None ) => continue ,
216+ res => return Poll :: Ready ( res. transpose ( ) ) ,
264217 }
265- other => return Poll :: Ready ( other) ,
266- } ,
267- Poll :: Pending => return Poll :: Pending ,
218+ }
219+ None => {
220+ self . is_closed = true ;
221+ // we have reached the end of the input stream but there could still
222+ // be buffered batches
223+ return match self . coalescer . finish ( ) {
224+ Ok ( None ) => Poll :: Ready ( None ) ,
225+ res => Poll :: Ready ( res. transpose ( ) ) ,
226+ } ;
227+ }
268228 }
269229 }
270230 }
271231}
272232
273233impl RecordBatchStream for CoalesceBatchesStream {
274234 fn schema ( & self ) -> SchemaRef {
275- Arc :: clone ( & self . schema )
235+ self . coalescer . schema ( )
276236 }
277237}
278238
@@ -290,26 +250,106 @@ pub fn concat_batches(
290250 arrow:: compute:: concat_batches ( schema, batches)
291251}
292252
253+ /// Concatenating multiple record batches into larger batches
254+ ///
255+ /// TODO ASCII ART
256+ ///
257+ /// Notes:
258+ ///
259+ /// 1. The output is exactly the same order as the input rows
260+ ///
261+ /// 2. The output is a sequence of batches, with all but the last being at least
262+ /// `target_batch_size` rows.
263+ ///
264+ /// 3. Eventually this may also be able to handle other optimizations such as a
265+ /// combined filter/coalesce operation.
266+ #[ derive( Debug ) ]
267+ struct BatchCoalescer {
268+ /// The input schema
269+ schema : SchemaRef ,
270+ /// Minimum number of rows for coalesces batches
271+ target_batch_size : usize ,
272+ /// Buffered batches
273+ buffer : Vec < RecordBatch > ,
274+ /// Buffered row count
275+ buffered_rows : usize ,
276+ }
277+
278+ impl BatchCoalescer {
279+ /// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows
280+ fn new ( schema : SchemaRef , target_batch_size : usize ) -> Self {
281+ Self {
282+ schema,
283+ target_batch_size,
284+ buffer : vec ! [ ] ,
285+ buffered_rows : 0 ,
286+ }
287+ }
288+
289+ /// Return the schema of the output batches
290+ fn schema ( & self ) -> SchemaRef {
291+ Arc :: clone ( & self . schema )
292+ }
293+
294+ /// Add a batch to the coalescer, returning a batch if the target batch size is reached
295+ fn push_batch ( & mut self , batch : RecordBatch ) -> Result < Option < RecordBatch > > {
296+ if batch. num_rows ( ) >= self . target_batch_size && self . buffer . is_empty ( ) {
297+ return Ok ( Some ( batch) ) ;
298+ }
299+ // discard empty batches
300+ if batch. num_rows ( ) == 0 {
301+ return Ok ( None ) ;
302+ }
303+ // add to the buffered batches
304+ self . buffered_rows += batch. num_rows ( ) ;
305+ self . buffer . push ( batch) ;
306+ // check to see if we have enough batches yet
307+ let batch = if self . buffered_rows >= self . target_batch_size {
308+ // combine the batches and return
309+ let batch = concat_batches ( & self . schema , & self . buffer , self . buffered_rows ) ?;
310+ // reset buffer state
311+ self . buffer . clear ( ) ;
312+ self . buffered_rows = 0 ;
313+ // return batch
314+ Some ( batch)
315+ } else {
316+ None
317+ } ;
318+ Ok ( batch)
319+ }
320+
321+ /// Finish the coalescing process, returning all buffered data as a final,
322+ /// single batch, if any
323+ fn finish ( & mut self ) -> Result < Option < RecordBatch > > {
324+ if self . buffer . is_empty ( ) {
325+ Ok ( None )
326+ } else {
327+ // combine the batches and return
328+ let batch = concat_batches ( & self . schema , & self . buffer , self . buffered_rows ) ?;
329+ // reset buffer state
330+ self . buffer . clear ( ) ;
331+ self . buffered_rows = 0 ;
332+ // return batch
333+ Ok ( Some ( batch) )
334+ }
335+ }
336+ }
337+
293338#[ cfg( test) ]
294339mod tests {
295340 use super :: * ;
296- use crate :: { memory:: MemoryExec , repartition:: RepartitionExec , Partitioning } ;
297-
298341 use arrow:: datatypes:: { DataType , Field , Schema } ;
299342 use arrow_array:: UInt32Array ;
300343
301344 #[ tokio:: test( flavor = "multi_thread" ) ]
302345 async fn test_concat_batches ( ) -> Result < ( ) > {
303- let schema = test_schema ( ) ;
304- let partition = create_vec_batches ( & schema, 10 ) ;
305- let partitions = vec ! [ partition] ;
306-
307- let output_partitions = coalesce_batches ( & schema, partitions, 21 ) . await ?;
308- assert_eq ! ( 1 , output_partitions. len( ) ) ;
346+ let Scenario { schema, batch } = uint32_scenario ( ) ;
309347
310348 // input is 10 batches x 8 rows (80 rows)
349+ let input = std:: iter:: repeat ( batch) . take ( 10 ) ;
350+
311351 // expected output is batches of at least 20 rows (except for the final batch)
312- let batches = & output_partitions [ 0 ] ;
352+ let batches = do_coalesce_batches ( & schema , input , 21 ) ;
313353 assert_eq ! ( 4 , batches. len( ) ) ;
314354 assert_eq ! ( 24 , batches[ 0 ] . num_rows( ) ) ;
315355 assert_eq ! ( 24 , batches[ 1 ] . num_rows( ) ) ;
@@ -319,54 +359,43 @@ mod tests {
319359 Ok ( ( ) )
320360 }
321361
322- fn test_schema ( ) -> Arc < Schema > {
323- Arc :: new ( Schema :: new ( vec ! [ Field :: new( "c0" , DataType :: UInt32 , false ) ] ) )
324- }
325-
326- async fn coalesce_batches (
362+ // Coalesce the batches with a BatchCoalescer function with the given input
363+ // and target batch size returning the resulting batches
364+ fn do_coalesce_batches (
327365 schema : & SchemaRef ,
328- input_partitions : Vec < Vec < RecordBatch > > ,
366+ input : impl IntoIterator < Item = RecordBatch > ,
329367 target_batch_size : usize ,
330- ) -> Result < Vec < Vec < RecordBatch > > > {
368+ ) -> Vec < RecordBatch > {
331369 // create physical plan
332- let exec = MemoryExec :: try_new ( & input_partitions, Arc :: clone ( schema) , None ) ?;
333- let exec =
334- RepartitionExec :: try_new ( Arc :: new ( exec) , Partitioning :: RoundRobinBatch ( 1 ) ) ?;
335- let exec: Arc < dyn ExecutionPlan > =
336- Arc :: new ( CoalesceBatchesExec :: new ( Arc :: new ( exec) , target_batch_size) ) ;
337-
338- // execute and collect results
339- let output_partition_count = exec. output_partitioning ( ) . partition_count ( ) ;
340- let mut output_partitions = Vec :: with_capacity ( output_partition_count) ;
341- for i in 0 ..output_partition_count {
342- // execute this *output* partition and collect all batches
343- let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
344- let mut stream = exec. execute ( i, Arc :: clone ( & task_ctx) ) ?;
345- let mut batches = vec ! [ ] ;
346- while let Some ( result) = stream. next ( ) . await {
347- batches. push ( result?) ;
348- }
349- output_partitions. push ( batches) ;
370+ let mut coalescer = BatchCoalescer :: new ( Arc :: clone ( schema) , target_batch_size) ;
371+ let mut output_batches: Vec < _ > = input
372+ . into_iter ( )
373+ . filter_map ( |batch| coalescer. push_batch ( batch) . unwrap ( ) )
374+ . collect ( ) ;
375+ if let Some ( batch) = coalescer. finish ( ) . unwrap ( ) {
376+ output_batches. push ( batch) ;
350377 }
351- Ok ( output_partitions )
378+ output_batches
352379 }
353380
354- /// Create vector batches
355- fn create_vec_batches ( schema : & Schema , n : usize ) -> Vec < RecordBatch > {
356- let batch = create_batch ( schema) ;
357- let mut vec = Vec :: with_capacity ( n) ;
358- for _ in 0 ..n {
359- vec. push ( batch. clone ( ) ) ;
360- }
361- vec
381+ /// Test scenario
382+ #[ derive( Debug ) ]
383+ struct Scenario {
384+ schema : Arc < Schema > ,
385+ batch : RecordBatch ,
362386 }
363387
364- /// Create batch
365- fn create_batch ( schema : & Schema ) -> RecordBatch {
366- RecordBatch :: try_new (
367- Arc :: new ( schema. clone ( ) ) ,
388+ /// a batch of 8 rows of UInt32
389+ fn uint32_scenario ( ) -> Scenario {
390+ let schema =
391+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( "c0" , DataType :: UInt32 , false ) ] ) ) ;
392+
393+ let batch = RecordBatch :: try_new (
394+ Arc :: clone ( & schema) ,
368395 vec ! [ Arc :: new( UInt32Array :: from( vec![ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) ) ] ,
369396 )
370- . unwrap ( )
397+ . unwrap ( ) ;
398+
399+ Scenario { schema, batch }
371400 }
372401}
0 commit comments