@@ -47,6 +47,21 @@ impl ProjectionPushDown {
4747        mapping :  & mut  HashMap < usize ,  usize > , 
4848    )  -> Result < Rc < LogicalPlan > >  { 
4949        match  plan { 
50+             LogicalPlan :: Selection  {  expr,  input }  => { 
51+                 // collect all columns referenced by filter expression 
52+                 self . collect_expr ( expr,  accum) ; 
53+ 
54+                 // push projection down 
55+                 let  input = self . optimize_plan ( & input,  accum,  mapping) ?; 
56+ 
57+                 // rewrite filter expression to use new column indexes 
58+                 let  new_expr = self . rewrite_expr ( expr,  mapping) ; 
59+ 
60+                 Ok ( Rc :: new ( LogicalPlan :: Selection  { 
61+                     expr :  new_expr, 
62+                     input, 
63+                 } ) ) 
64+             } 
5065            LogicalPlan :: Aggregate  { 
5166                input, 
5267                group_expr, 
@@ -88,6 +103,9 @@ impl ProjectionPushDown {
88103                let  mut  projection:  Vec < usize >  = Vec :: with_capacity ( accum. len ( ) ) ; 
89104                accum. iter ( ) . for_each ( |i| projection. push ( * i) ) ; 
90105
106+                 // sort the projection otherwise we get non-deterministic behavior 
107+                 projection. sort ( ) ; 
108+ 
91109                // now that the table scan is returning a different schema we need to create a 
92110                // mapping from the original column index to the new column index so that we 
93111                // can rewrite expressions as we walk back up the tree 
@@ -184,4 +202,88 @@ mod tests {
184202
185203        assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[]], aggr=[[#0]]\n   TableScan: test projection=Some([1])" ) ; 
186204    } 
205+ 
/// Projection push-down through an aggregate that has a group-by expression.
///
/// The table exposes columns `a`, `b`, `c` (indexes 0, 1, 2). The aggregate
/// groups by column 2 (`c`) and aggregates column 1 (`b`), so the optimized
/// scan is expected to project `[1, 2]` and the aggregate's column references
/// are expected to be rewritten to `#1` (group by) and `#0` (aggregate).
#[test]
fn aggregate_group_by() {
    // every column in this test is a non-nullable UInt32
    let uint32 = |name: &str| Field::new(name, DataType::UInt32, false);

    // unoptimized plan modelling SELECT c, MAX(b) FROM default.test GROUP BY c
    // (the aggregate expression is a bare reference to column `b`)
    let scan = TableScan {
        schema_name: String::from("default"),
        table_name: String::from("test"),
        schema: Arc::new(Schema::new(vec![uint32("a"), uint32("b"), uint32("c")])),
        projection: None,
    };

    let plan = Aggregate {
        group_expr: vec![Column(2)],
        aggr_expr: vec![Column(1)],
        schema: Arc::new(Schema::new(vec![uint32("c"), uint32("MAX(b)")])),
        input: Rc::new(scan),
    };

    // apply the projection push-down rule through the trait-object handle
    let rule: Rc<RefCell<OptimizerRule>> =
        Rc::new(RefCell::new(ProjectionPushDown::new()));
    let optimized = rule.borrow_mut().optimize(&plan).unwrap();

    // compare against the debug rendering of the optimized plan
    let expected =
        "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n   TableScan: test projection=Some([1, 2])";
    assert_eq!(format!("{:?}", optimized), expected);
}
244+ 
/// Projection push-down through an aggregate sitting on top of a selection.
///
/// The table exposes columns `a`, `b`, `c` (indexes 0, 1, 2). The selection
/// filters on column 2 (`c`) and the aggregate reads column 1 (`b`), so the
/// optimized scan is expected to project `[1, 2]`, with the filter rewritten
/// to `#1` and the aggregate expression rewritten to `#0`.
#[test]
fn aggregate_no_group_by_with_selection() {
    // every column in this test is a non-nullable UInt32
    let uint32 = |name: &str| Field::new(name, DataType::UInt32, false);

    // unoptimized plan modelling SELECT MAX(b) FROM default.test WHERE c
    // (filter and aggregate expressions are bare column references)
    let scan = TableScan {
        schema_name: String::from("default"),
        table_name: String::from("test"),
        schema: Arc::new(Schema::new(vec![uint32("a"), uint32("b"), uint32("c")])),
        projection: None,
    };

    let filter = Selection {
        expr: Column(2),
        input: Rc::new(scan),
    };

    let plan = Aggregate {
        group_expr: vec![],
        aggr_expr: vec![Column(1)],
        schema: Arc::new(Schema::new(vec![uint32("MAX(b)")])),
        input: Rc::new(filter),
    };

    // apply the projection push-down rule through the trait-object handle
    let rule: Rc<RefCell<OptimizerRule>> =
        Rc::new(RefCell::new(ProjectionPushDown::new()));
    let optimized = rule.borrow_mut().optimize(&plan).unwrap();

    // compare against the debug rendering of the optimized plan
    let expected = "Aggregate: groupBy=[[]], aggr=[[#0]]\n   Selection: #1\n     TableScan: test projection=Some([1, 2])";
    assert_eq!(format!("{:?}", optimized), expected);
}
187289} 
0 commit comments