@@ -24,6 +24,7 @@ use std::pin::Pin;
2424use std:: sync:: Arc ;
2525use std:: task:: { Context , Poll } ;
2626
27+ use arrow:: array:: DynComparator ;
2728use arrow:: {
2829 array:: { make_array as make_arrow_array, ArrayRef , MutableArrayData } ,
2930 compute:: SortOptions ,
@@ -35,6 +36,7 @@ use async_trait::async_trait;
3536use futures:: channel:: mpsc;
3637use futures:: stream:: FusedStream ;
3738use futures:: { Stream , StreamExt } ;
39+ use hashbrown:: HashMap ;
3840
3941use crate :: error:: { DataFusionError , Result } ;
4042use crate :: physical_plan:: {
@@ -176,34 +178,60 @@ impl ExecutionPlan for SortPreservingMergeExec {
176178 }
177179}
178180
179- /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of `PhysicalExpr` that when
180- /// evaluated on the `RecordBatch` yield the sort keys.
181+ /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
182+ /// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
181183///
182184/// Additionally it maintains a row cursor that can be advanced through the rows
183185/// of the provided `RecordBatch`
184186///
185- /// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this
186- /// row cursor, with that of another `SortKeyCursor`
187- # [ derive ( Debug , Clone ) ]
187+ /// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
188+ /// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
189+ /// a row comparator for each other cursor that it is compared to.
188190struct SortKeyCursor {
189191 columns : Vec < ArrayRef > ,
190- batch : RecordBatch ,
191192 cur_row : usize ,
192193 num_rows : usize ,
194+
195+ // An index uniquely identifying the record batch scanned by this cursor.
196+ batch_idx : usize ,
197+ batch : RecordBatch ,
198+
199+ // A collection of comparators that compare rows in this cursor's batch to
200+ // the cursors in other batches. Other batches are uniquely identified by
201+ // their batch_idx.
202+ batch_comparators : HashMap < usize , Vec < DynComparator > > ,
203+ }
204+
205+ impl < ' a > std:: fmt:: Debug for SortKeyCursor {
206+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
207+ f. debug_struct ( "SortKeyCursor" )
208+ . field ( "columns" , & self . columns )
209+ . field ( "cur_row" , & self . cur_row )
210+ . field ( "num_rows" , & self . num_rows )
211+ . field ( "batch_idx" , & self . batch_idx )
212+ . field ( "batch" , & self . batch )
213+ . field ( "batch_comparators" , & "<FUNC>" )
214+ . finish ( )
215+ }
193216}
194217
195218impl SortKeyCursor {
196- fn new ( batch : RecordBatch , sort_key : & [ Arc < dyn PhysicalExpr > ] ) -> Result < Self > {
219+ fn new (
220+ batch_idx : usize ,
221+ batch : RecordBatch ,
222+ sort_key : & [ Arc < dyn PhysicalExpr > ] ,
223+ ) -> Result < Self > {
197224 let columns = sort_key
198225 . iter ( )
199226 . map ( |expr| Ok ( expr. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ) )
200227 . collect :: < Result < _ > > ( ) ?;
201-
202228 Ok ( Self {
203229 cur_row : 0 ,
204230 num_rows : batch. num_rows ( ) ,
205231 columns,
206232 batch,
233+ batch_idx,
234+ batch_comparators : HashMap :: new ( ) ,
207235 } )
208236 }
209237
@@ -220,7 +248,7 @@ impl SortKeyCursor {
220248
221249 /// Compares the sort key pointed to by this instance's row cursor with that of another
222250 fn compare (
223- & self ,
251+ & mut self ,
224252 other : & SortKeyCursor ,
225253 options : & [ SortOptions ] ,
226254 ) -> Result < Ordering > {
@@ -246,7 +274,19 @@ impl SortKeyCursor {
246274 . zip ( other. columns . iter ( ) )
247275 . zip ( options. iter ( ) ) ;
248276
249- for ( ( l, r) , sort_options) in zipped {
277+ // Recall or initialise a collection of comparators for comparing
278+ // columnar arrays of this cursor and "other".
279+ let cmp = self
280+ . batch_comparators
281+ . entry ( other. batch_idx )
282+ . or_insert_with ( || Vec :: with_capacity ( other. columns . len ( ) ) ) ;
283+
284+ for ( i, ( ( l, r) , sort_options) ) in zipped. enumerate ( ) {
285+ if i >= cmp. len ( ) {
286+ // lazily build the comparator for column `i` the first time this pair of batches is compared on it
287+ cmp. push ( arrow:: array:: build_compare ( l. as_ref ( ) , r. as_ref ( ) ) ?) ;
288+ }
289+
250290 match ( l. is_valid ( self . cur_row ) , r. is_valid ( other. cur_row ) ) {
251291 ( false , true ) if sort_options. nulls_first => return Ok ( Ordering :: Less ) ,
252292 ( false , true ) => return Ok ( Ordering :: Greater ) ,
@@ -255,15 +295,11 @@ impl SortKeyCursor {
255295 }
256296 ( true , false ) => return Ok ( Ordering :: Less ) ,
257297 ( false , false ) => { }
258- ( true , true ) => {
259- // TODO: Building the predicate each time is sub-optimal
260- let c = arrow:: array:: build_compare ( l. as_ref ( ) , r. as_ref ( ) ) ?;
261- match c ( self . cur_row , other. cur_row ) {
262- Ordering :: Equal => { }
263- o if sort_options. descending => return Ok ( o. reverse ( ) ) ,
264- o => return Ok ( o) ,
265- }
266- }
298+ ( true , true ) => match cmp[ i] ( self . cur_row , other. cur_row ) {
299+ Ordering :: Equal => { }
300+ o if sort_options. descending => return Ok ( o. reverse ( ) ) ,
301+ o => return Ok ( o) ,
302+ } ,
267303 }
268304 }
269305
@@ -304,6 +340,9 @@ struct SortPreservingMergeStream {
304340 target_batch_size : usize ,
305341 /// If the stream has encountered an error
306342 aborted : bool ,
343+
344+ /// The index to assign to the next input batch; incremented each time a cursor is created so that every batch gets a unique ID
345+ next_batch_index : usize ,
307346}
308347
309348impl SortPreservingMergeStream {
@@ -313,15 +352,21 @@ impl SortPreservingMergeStream {
313352 expressions : & [ PhysicalSortExpr ] ,
314353 target_batch_size : usize ,
315354 ) -> Self {
355+ let cursors = ( 0 ..streams. len ( ) )
356+ . into_iter ( )
357+ . map ( |_| VecDeque :: new ( ) )
358+ . collect ( ) ;
359+
316360 Self {
317361 schema,
318- cursors : vec ! [ Default :: default ( ) ; streams . len ( ) ] ,
362+ cursors,
319363 streams,
320364 column_expressions : expressions. iter ( ) . map ( |x| x. expr . clone ( ) ) . collect ( ) ,
321365 sort_options : expressions. iter ( ) . map ( |x| x. options ) . collect ( ) ,
322366 target_batch_size,
323367 aborted : false ,
324368 in_progress : vec ! [ ] ,
369+ next_batch_index : 0 ,
325370 }
326371 }
327372
@@ -352,12 +397,17 @@ impl SortPreservingMergeStream {
352397 return Poll :: Ready ( Err ( e) ) ;
353398 }
354399 Some ( Ok ( batch) ) => {
355- let cursor = match SortKeyCursor :: new ( batch, & self . column_expressions ) {
400+ let cursor = match SortKeyCursor :: new (
401+ self . next_batch_index , // assign this batch an ID
402+ batch,
403+ & self . column_expressions ,
404+ ) {
356405 Ok ( cursor) => cursor,
357406 Err ( e) => {
358407 return Poll :: Ready ( Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ) ;
359408 }
360409 } ;
410+ self . next_batch_index += 1 ;
361411 self . cursors [ idx] . push_back ( cursor)
362412 }
363413 }
@@ -367,17 +417,17 @@ impl SortPreservingMergeStream {
367417
368418 /// Returns the index of the next stream to pull a row from, or None
369419 /// if all cursors for all streams are exhausted
370- fn next_stream_idx ( & self ) -> Result < Option < usize > > {
371- let mut min_cursor: Option < ( usize , & SortKeyCursor ) > = None ;
372- for ( idx, candidate) in self . cursors . iter ( ) . enumerate ( ) {
373- if let Some ( candidate) = candidate. back ( ) {
420+ fn next_stream_idx ( & mut self ) -> Result < Option < usize > > {
421+ let mut min_cursor: Option < ( usize , & mut SortKeyCursor ) > = None ;
422+ for ( idx, candidate) in self . cursors . iter_mut ( ) . enumerate ( ) {
423+ if let Some ( candidate) = candidate. back_mut ( ) {
374424 if candidate. is_finished ( ) {
375425 continue ;
376426 }
377427
378428 match min_cursor {
379429 None => min_cursor = Some ( ( idx, candidate) ) ,
380- Some ( ( _, min) ) => {
430+ Some ( ( _, ref mut min) ) => {
381431 if min. compare ( candidate, & self . sort_options ) ?
382432 == Ordering :: Greater
383433 {
@@ -599,8 +649,7 @@ mod tests {
599649 let b2 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
600650
601651 _test_merge (
602- b1,
603- b2,
652+ & [ vec ! [ b1] , vec ! [ b2] ] ,
604653 & [
605654 "+----+---+-------------------------------+" ,
606655 "| a | b | c |" ,
@@ -646,8 +695,7 @@ mod tests {
646695 let b2 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
647696
648697 _test_merge (
649- b1,
650- b2,
698+ & [ vec ! [ b1] , vec ! [ b2] ] ,
651699 & [
652700 "+-----+---+-------------------------------+" ,
653701 "| a | b | c |" ,
@@ -693,8 +741,7 @@ mod tests {
693741 let b2 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
694742
695743 _test_merge (
696- b1,
697- b2,
744+ & [ vec ! [ b1] , vec ! [ b2] ] ,
698745 & [
699746 "+----+---+-------------------------------+" ,
700747 "| a | b | c |" ,
@@ -715,8 +762,71 @@ mod tests {
715762 . await ;
716763 }
717764
718- async fn _test_merge ( b1 : RecordBatch , b2 : RecordBatch , exp : & [ & str ] ) {
719- let schema = b1. schema ( ) ;
765+ #[ tokio:: test]
766+ async fn test_merge_three_partitions ( ) {
767+ let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 7 , 9 , 3 ] ) ) ;
768+ let b: ArrayRef = Arc :: new ( StringArray :: from_iter ( vec ! [
769+ Some ( "a" ) ,
770+ Some ( "b" ) ,
771+ Some ( "c" ) ,
772+ Some ( "d" ) ,
773+ Some ( "f" ) ,
774+ ] ) ) ;
775+ let c: ArrayRef = Arc :: new ( TimestampNanosecondArray :: from ( vec ! [ 8 , 7 , 6 , 5 , 8 ] ) ) ;
776+ let b1 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
777+
778+ let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 10 , 20 , 70 , 90 , 30 ] ) ) ;
779+ let b: ArrayRef = Arc :: new ( StringArray :: from_iter ( vec ! [
780+ Some ( "e" ) ,
781+ Some ( "g" ) ,
782+ Some ( "h" ) ,
783+ Some ( "i" ) ,
784+ Some ( "j" ) ,
785+ ] ) ) ;
786+ let c: ArrayRef =
787+ Arc :: new ( TimestampNanosecondArray :: from ( vec ! [ 40 , 60 , 20 , 20 , 60 ] ) ) ;
788+ let b2 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
789+
790+ let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 100 , 200 , 700 , 900 , 300 ] ) ) ;
791+ let b: ArrayRef = Arc :: new ( StringArray :: from_iter ( vec ! [
792+ Some ( "f" ) ,
793+ Some ( "g" ) ,
794+ Some ( "h" ) ,
795+ Some ( "i" ) ,
796+ Some ( "j" ) ,
797+ ] ) ) ;
798+ let c: ArrayRef = Arc :: new ( TimestampNanosecondArray :: from ( vec ! [ 4 , 6 , 2 , 2 , 6 ] ) ) ;
799+ let b3 = RecordBatch :: try_from_iter ( vec ! [ ( "a" , a) , ( "b" , b) , ( "c" , c) ] ) . unwrap ( ) ;
800+
801+ _test_merge (
802+ & [ vec ! [ b1] , vec ! [ b2] , vec ! [ b3] ] ,
803+ & [
804+ "+-----+---+-------------------------------+" ,
805+ "| a | b | c |" ,
806+ "+-----+---+-------------------------------+" ,
807+ "| 1 | a | 1970-01-01 00:00:00.000000008 |" ,
808+ "| 2 | b | 1970-01-01 00:00:00.000000007 |" ,
809+ "| 7 | c | 1970-01-01 00:00:00.000000006 |" ,
810+ "| 9 | d | 1970-01-01 00:00:00.000000005 |" ,
811+ "| 10 | e | 1970-01-01 00:00:00.000000040 |" ,
812+ "| 100 | f | 1970-01-01 00:00:00.000000004 |" ,
813+ "| 3 | f | 1970-01-01 00:00:00.000000008 |" ,
814+ "| 200 | g | 1970-01-01 00:00:00.000000006 |" ,
815+ "| 20 | g | 1970-01-01 00:00:00.000000060 |" ,
816+ "| 700 | h | 1970-01-01 00:00:00.000000002 |" ,
817+ "| 70 | h | 1970-01-01 00:00:00.000000020 |" ,
818+ "| 900 | i | 1970-01-01 00:00:00.000000002 |" ,
819+ "| 90 | i | 1970-01-01 00:00:00.000000020 |" ,
820+ "| 300 | j | 1970-01-01 00:00:00.000000006 |" ,
821+ "| 30 | j | 1970-01-01 00:00:00.000000060 |" ,
822+ "+-----+---+-------------------------------+" ,
823+ ] ,
824+ )
825+ . await ;
826+ }
827+
828+ async fn _test_merge ( partitions : & [ Vec < RecordBatch > ] , exp : & [ & str ] ) {
829+ let schema = partitions[ 0 ] [ 0 ] . schema ( ) ;
720830 let sort = vec ! [
721831 PhysicalSortExpr {
722832 expr: col( "b" , & schema) . unwrap( ) ,
@@ -727,12 +837,10 @@ mod tests {
727837 options: Default :: default ( ) ,
728838 } ,
729839 ] ;
730- let exec = MemoryExec :: try_new ( & [ vec ! [ b1 ] , vec ! [ b2 ] ] , schema, None ) . unwrap ( ) ;
840+ let exec = MemoryExec :: try_new ( partitions , schema, None ) . unwrap ( ) ;
731841 let merge = Arc :: new ( SortPreservingMergeExec :: new ( sort, Arc :: new ( exec) , 1024 ) ) ;
732842
733843 let collected = collect ( merge) . await . unwrap ( ) ;
734- assert_eq ! ( collected. len( ) , 1 ) ;
735-
736844 assert_batches_eq ! ( exp, collected. as_slice( ) ) ;
737845 }
738846
0 commit comments