2020use std:: sync:: Arc ;
2121
2222use arrow:: {
23- array:: { ArrayRef , Int32Array } ,
23+ array:: { as_string_array , ArrayRef , Int32Array , StringArray } ,
2424 compute:: SortOptions ,
2525 record_batch:: RecordBatch ,
2626} ;
@@ -29,6 +29,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
2929use datafusion:: physical_plan:: sorts:: sort:: SortExec ;
3030use datafusion:: physical_plan:: { collect, ExecutionPlan } ;
3131use datafusion:: prelude:: { SessionConfig , SessionContext } ;
32+ use datafusion_common:: cast:: as_int32_array;
3233use datafusion_execution:: memory_pool:: GreedyMemoryPool ;
3334use datafusion_physical_expr:: expressions:: col;
3435use datafusion_physical_expr_common:: sort_expr:: LexOrdering ;
@@ -42,12 +43,17 @@ const KB: usize = 1 << 10;
4243#[ cfg_attr( tarpaulin, ignore) ]
4344async fn test_sort_10k_mem ( ) {
4445 for ( batch_size, should_spill) in [ ( 5 , false ) , ( 20000 , true ) , ( 500000 , true ) ] {
45- SortTest :: new ( )
46+ let ( input , collected ) = SortTest :: new ( )
4647 . with_int32_batches ( batch_size)
48+ . with_sort_columns ( vec ! [ "x" ] )
4749 . with_pool_size ( 10 * KB )
4850 . with_should_spill ( should_spill)
4951 . run ( )
5052 . await ;
53+
54+ let expected = partitions_to_sorted_vec ( & input) ;
55+ let actual = batches_to_vec ( & collected) ;
56+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
5157 }
5258}
5359
@@ -57,29 +63,119 @@ async fn test_sort_100k_mem() {
5763 for ( batch_size, should_spill) in
5864 [ ( 5 , false ) , ( 10000 , false ) , ( 20000 , true ) , ( 1000000 , true ) ]
5965 {
60- SortTest :: new ( )
66+ let ( input , collected ) = SortTest :: new ( )
6167 . with_int32_batches ( batch_size)
68+ . with_sort_columns ( vec ! [ "x" ] )
69+ . with_pool_size ( 100 * KB )
70+ . with_should_spill ( should_spill)
71+ . run ( )
72+ . await ;
73+
74+ let expected = partitions_to_sorted_vec ( & input) ;
75+ let actual = batches_to_vec ( & collected) ;
76+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
77+ }
78+ }
79+
80+ #[ tokio:: test]
81+ #[ cfg_attr( tarpaulin, ignore) ]
82+ async fn test_sort_strings_100k_mem ( ) {
83+ for ( batch_size, should_spill) in
84+ [ ( 5 , false ) , ( 1000 , false ) , ( 10000 , true ) , ( 20000 , true ) ]
85+ {
86+ let ( input, collected) = SortTest :: new ( )
87+ . with_utf8_batches ( batch_size)
88+ . with_sort_columns ( vec ! [ "x" ] )
6289 . with_pool_size ( 100 * KB )
6390 . with_should_spill ( should_spill)
6491 . run ( )
6592 . await ;
93+
94+ let mut input = input
95+ . iter ( )
96+ . flat_map ( |p| p. iter ( ) )
97+ . flat_map ( |b| {
98+ let array = b. column ( 0 ) ;
99+ as_string_array ( array)
100+ . iter ( )
101+ . map ( |s| s. unwrap ( ) . to_string ( ) )
102+ } )
103+ . collect :: < Vec < String > > ( ) ;
104+ input. sort_unstable ( ) ;
105+ let actual = collected
106+ . iter ( )
107+ . flat_map ( |b| {
108+ let array = b. column ( 0 ) ;
109+ as_string_array ( array)
110+ . iter ( )
111+ . map ( |s| s. unwrap ( ) . to_string ( ) )
112+ } )
113+ . collect :: < Vec < String > > ( ) ;
114+ assert_eq ! ( input, actual) ;
115+ }
116+ }
117+
118+ #[ tokio:: test]
119+ #[ cfg_attr( tarpaulin, ignore) ]
120+ async fn test_sort_multi_columns_100k_mem ( ) {
121+ for ( batch_size, should_spill) in
122+ [ ( 5 , false ) , ( 1000 , false ) , ( 10000 , true ) , ( 20000 , true ) ]
123+ {
124+ let ( input, collected) = SortTest :: new ( )
125+ . with_int32_utf8_batches ( batch_size)
126+ . with_sort_columns ( vec ! [ "x" , "y" ] )
127+ . with_pool_size ( 100 * KB )
128+ . with_should_spill ( should_spill)
129+ . run ( )
130+ . await ;
131+
132+ fn record_batch_to_vec ( b : & RecordBatch ) -> Vec < ( i32 , String ) > {
133+ let mut rows: Vec < _ > = Vec :: new ( ) ;
134+ let i32_array = as_int32_array ( b. column ( 0 ) ) . unwrap ( ) ;
135+ let string_array = as_string_array ( b. column ( 1 ) ) ;
136+ for i in 0 ..b. num_rows ( ) {
137+ let str = string_array. value ( i) . to_string ( ) ;
138+ let i32 = i32_array. value ( i) ;
139+ rows. push ( ( i32, str) ) ;
140+ }
141+ rows
142+ }
143+ let mut input = input
144+ . iter ( )
145+ . flat_map ( |p| p. iter ( ) )
146+ . flat_map ( record_batch_to_vec)
147+ . collect :: < Vec < ( i32 , String ) > > ( ) ;
148+ input. sort_unstable ( ) ;
149+ let actual = collected
150+ . iter ( )
151+ . flat_map ( record_batch_to_vec)
152+ . collect :: < Vec < ( i32 , String ) > > ( ) ;
153+ assert_eq ! ( input, actual) ;
66154 }
67155}
68156
69157#[ tokio:: test]
70158async fn test_sort_unlimited_mem ( ) {
71159 for ( batch_size, should_spill) in [ ( 5 , false ) , ( 20000 , false ) , ( 1000000 , false ) ] {
72- SortTest :: new ( )
160+ let ( input , collected ) = SortTest :: new ( )
73161 . with_int32_batches ( batch_size)
162+ . with_sort_columns ( vec ! [ "x" ] )
74163 . with_pool_size ( usize:: MAX )
75164 . with_should_spill ( should_spill)
76165 . run ( )
77166 . await ;
167+
168+ let expected = partitions_to_sorted_vec ( & input) ;
169+ let actual = batches_to_vec ( & collected) ;
170+ assert_eq ! ( expected, actual, "failure in @ batch_size {batch_size:?}" ) ;
78171 }
79172}
173+
80174#[ derive( Debug , Default ) ]
81175struct SortTest {
82176 input : Vec < Vec < RecordBatch > > ,
177+ /// The names of the columns to sort by
178+ sort_columns : Vec < String > ,
83179 /// GreedyMemoryPool size, if specified
84180 pool_size : Option < usize > ,
85181 /// If true, expect the sort to spill
@@ -91,12 +187,29 @@ impl SortTest {
91187 Default :: default ( )
92188 }
93189
190+ fn with_sort_columns ( mut self , sort_columns : Vec < & str > ) -> Self {
191+ self . sort_columns = sort_columns. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
192+ self
193+ }
194+
94195 /// Create batches of int32 values of rows
95196 fn with_int32_batches ( mut self , rows : usize ) -> Self {
96197 self . input = vec ! [ make_staggered_i32_batches( rows) ] ;
97198 self
98199 }
99200
201+ /// Create batches of utf8 values of rows
202+ fn with_utf8_batches ( mut self , rows : usize ) -> Self {
203+ self . input = vec ! [ make_staggered_utf8_batches( rows) ] ;
204+ self
205+ }
206+
207+ /// Create batches of int32 and utf8 values of rows
208+ fn with_int32_utf8_batches ( mut self , rows : usize ) -> Self {
209+ self . input = vec ! [ make_staggered_i32_utf8_batches( rows) ] ;
210+ self
211+ }
212+
100213 /// specify that this test should use a memory pool of the specified size
101214 fn with_pool_size ( mut self , pool_size : usize ) -> Self {
102215 self . pool_size = Some ( pool_size) ;
@@ -110,7 +223,7 @@ impl SortTest {
110223
111224 /// Sort the input using SortExec and ensure the results are
112225 /// correct according to `Vec::sort` both with and without spilling
113- async fn run ( & self ) {
226+ async fn run ( & self ) -> ( Vec < Vec < RecordBatch > > , Vec < RecordBatch > ) {
114227 let input = self . input . clone ( ) ;
115228 let first_batch = input
116229 . iter ( )
@@ -119,16 +232,21 @@ impl SortTest {
119232 . expect ( "at least one batch" ) ;
120233 let schema = first_batch. schema ( ) ;
121234
122- let sort = LexOrdering :: new ( vec ! [ PhysicalSortExpr {
123- expr: col( "x" , & schema) . unwrap( ) ,
124- options: SortOptions {
125- descending: false ,
126- nulls_first: true ,
127- } ,
128- } ] ) ;
235+ let sort_ordering = LexOrdering :: new (
236+ self . sort_columns
237+ . iter ( )
238+ . map ( |c| PhysicalSortExpr {
239+ expr : col ( c, & schema) . unwrap ( ) ,
240+ options : SortOptions {
241+ descending : false ,
242+ nulls_first : true ,
243+ } ,
244+ } )
245+ . collect ( ) ,
246+ ) ;
129247
130248 let exec = MemorySourceConfig :: try_new_exec ( & input, schema, None ) . unwrap ( ) ;
131- let sort = Arc :: new ( SortExec :: new ( sort , exec) ) ;
249+ let sort = Arc :: new ( SortExec :: new ( sort_ordering , exec) ) ;
132250
133251 let session_config = SessionConfig :: new ( ) ;
134252 let session_ctx = if let Some ( pool_size) = self . pool_size {
@@ -153,9 +271,6 @@ impl SortTest {
153271 let task_ctx = session_ctx. task_ctx ( ) ;
154272 let collected = collect ( sort. clone ( ) , task_ctx) . await . unwrap ( ) ;
155273
156- let expected = partitions_to_sorted_vec ( & input) ;
157- let actual = batches_to_vec ( & collected) ;
158-
159274 if self . should_spill {
160275 assert_ne ! (
161276 sort. metrics( ) . unwrap( ) . spill_count( ) . unwrap( ) ,
@@ -175,7 +290,8 @@ impl SortTest {
175290 0 ,
176291 "The sort should have returned all memory used back to the memory pool"
177292 ) ;
178- assert_eq ! ( expected, actual, "failure in @ pool_size {self:?}" ) ;
293+
294+ ( input, collected)
179295 }
180296}
181297
@@ -203,3 +319,63 @@ fn make_staggered_i32_batches(len: usize) -> Vec<RecordBatch> {
203319 }
204320 batches
205321}
322+
323+ /// Return randomly sized record batches in a field named 'x' of type `Utf8`
324+ /// with randomized content
325+ fn make_staggered_utf8_batches ( len : usize ) -> Vec < RecordBatch > {
326+ let mut rng = rand:: thread_rng ( ) ;
327+ let max_batch = 1024 ;
328+
329+ let mut batches = vec ! [ ] ;
330+ let mut remaining = len;
331+ while remaining != 0 {
332+ let to_read = rng. gen_range ( 0 ..=remaining. min ( max_batch) ) ;
333+ remaining -= to_read;
334+
335+ batches. push (
336+ RecordBatch :: try_from_iter ( vec ! [ (
337+ "x" ,
338+ Arc :: new( StringArray :: from_iter_values(
339+ ( 0 ..to_read) . map( |_| format!( "test_string_{}" , rng. gen :: <u32 >( ) ) ) ,
340+ ) ) as ArrayRef ,
341+ ) ] )
342+ . unwrap ( ) ,
343+ )
344+ }
345+ batches
346+ }
347+
348+ /// Return randomly sized record batches in a field named 'x' of type `Int32`
349+ /// with randomized i32 content and a field named 'y' of type `Utf8`
350+ /// with randomized content
351+ fn make_staggered_i32_utf8_batches ( len : usize ) -> Vec < RecordBatch > {
352+ let mut rng = rand:: thread_rng ( ) ;
353+ let max_batch = 1024 ;
354+
355+ let mut batches = vec ! [ ] ;
356+ let mut remaining = len;
357+ while remaining != 0 {
358+ let to_read = rng. gen_range ( 0 ..=remaining. min ( max_batch) ) ;
359+ remaining -= to_read;
360+
361+ batches. push (
362+ RecordBatch :: try_from_iter ( vec ! [
363+ (
364+ "x" ,
365+ Arc :: new( Int32Array :: from_iter_values(
366+ ( 0 ..to_read) . map( |_| rng. gen ( ) ) ,
367+ ) ) as ArrayRef ,
368+ ) ,
369+ (
370+ "y" ,
371+ Arc :: new( StringArray :: from_iter_values(
372+ ( 0 ..to_read) . map( |_| format!( "test_string_{}" , rng. gen :: <u32 >( ) ) ) ,
373+ ) ) as ArrayRef ,
374+ ) ,
375+ ] )
376+ . unwrap ( ) ,
377+ )
378+ }
379+
380+ batches
381+ }
0 commit comments