@@ -29,7 +29,7 @@ pub struct ProjectionPushDown {}
2929
3030impl  OptimizerRule  for  ProjectionPushDown  { 
3131    fn  optimize ( & mut  self ,  plan :  & LogicalPlan )  -> Result < Rc < LogicalPlan > >  { 
32-         let  mut  accum = HashSet :: new ( ) ; 
32+         let  mut  accum:   HashSet < usize >  = HashSet :: new ( ) ; 
3333        let  mut  mapping:  HashMap < usize ,  usize >  = HashMap :: new ( ) ; 
3434        self . optimize_plan ( plan,  & mut  accum,  & mut  mapping) 
3535    } 
@@ -47,6 +47,29 @@ impl ProjectionPushDown {
4747        mapping :  & mut  HashMap < usize ,  usize > , 
4848    )  -> Result < Rc < LogicalPlan > >  { 
4949        match  plan { 
50+             LogicalPlan :: Projection  { 
51+                 expr, 
52+                 input, 
53+                 schema, 
54+             }  => { 
55+                 // collect all columns referenced by projection expressions 
56+                 expr. iter ( ) . for_each ( |e| self . collect_expr ( e,  accum) ) ; 
57+ 
58+                 // push projection down 
59+                 let  input = self . optimize_plan ( & input,  accum,  mapping) ?; 
60+ 
61+                 // rewrite projection expressions to use new column indexes 
62+                 let  new_expr = expr
63+                     . iter ( ) 
64+                     . map ( |e| self . rewrite_expr ( e,  mapping) ) 
65+                     . collect :: < Result < Vec < Expr > > > ( ) ?; 
66+ 
67+                 Ok ( Rc :: new ( LogicalPlan :: Projection  { 
68+                     expr :  new_expr, 
69+                     input, 
70+                     schema :  schema. clone ( ) , 
71+                 } ) ) 
72+             } 
5073            LogicalPlan :: Selection  {  expr,  input }  => { 
5174                // collect all columns referenced by filter expression 
5275                self . collect_expr ( expr,  accum) ; 
@@ -92,6 +115,34 @@ impl ProjectionPushDown {
92115                    schema :  schema. clone ( ) , 
93116                } ) ) 
94117            } 
118+             LogicalPlan :: Sort  { 
119+                 expr, 
120+                 input, 
121+                 schema, 
122+             }  => { 
123+                 // collect all columns referenced by sort expressions 
124+                 expr. iter ( ) . for_each ( |e| self . collect_expr ( e,  accum) ) ; 
125+ 
126+                 // push projection down 
127+                 let  input = self . optimize_plan ( & input,  accum,  mapping) ?; 
128+ 
129+                 // rewrite sort expressions to use new column indexes 
130+                 let  new_expr = expr
131+                     . iter ( ) 
132+                     . map ( |e| self . rewrite_expr ( e,  mapping) ) 
133+                     . collect :: < Result < Vec < Expr > > > ( ) ?; 
134+ 
135+                 Ok ( Rc :: new ( LogicalPlan :: Sort  { 
136+                     expr :  new_expr, 
137+                     input, 
138+                     schema :  schema. clone ( ) , 
139+                 } ) ) 
140+             } 
141+             LogicalPlan :: EmptyRelation  {  schema }  => { 
142+                 Ok ( Rc :: new ( LogicalPlan :: EmptyRelation  { 
143+                     schema :  schema. clone ( ) , 
144+                 } ) ) 
145+             } 
95146            LogicalPlan :: TableScan  { 
96147                schema_name, 
97148                table_name, 
@@ -128,8 +179,6 @@ impl ProjectionPushDown {
128179                    projection :  Some ( projection) , 
129180                } ) ) 
130181            } 
131-             //TODO implement all logical plan variants and remove this unimplemented 
132-             _ => Err ( ArrowError :: ComputeError ( "unimplemented" . to_string ( ) ) ) , 
133182        } 
134183    } 
135184
@@ -227,20 +276,7 @@ mod tests {
227276
228277    #[ test]  
229278    fn  aggregate_no_group_by ( )  { 
230-         let  schema = Schema :: new ( vec ! [ 
231-             Field :: new( "a" ,  DataType :: UInt32 ,  false ) , 
232-             Field :: new( "b" ,  DataType :: UInt32 ,  false ) , 
233-             Field :: new( "c" ,  DataType :: UInt32 ,  false ) , 
234-         ] ) ; 
235- 
236-         // create unoptimized plan for SELECT MAX(b) FROM default.test 
237- 
238-         let  table_scan = TableScan  { 
239-             schema_name :  "default" . to_string ( ) , 
240-             table_name :  "test" . to_string ( ) , 
241-             schema :  Arc :: new ( schema) , 
242-             projection :  None , 
243-         } ; 
279+         let  table_scan = test_table_scan ( ) ; 
244280
245281        let  aggregate = Aggregate  { 
246282            group_expr :  vec ! [ ] , 
@@ -253,34 +289,12 @@ mod tests {
253289            input :  Rc :: new ( table_scan) , 
254290        } ; 
255291
256-         // run optimizer rule 
257- 
258-         let  rule:  Rc < RefCell < OptimizerRule > >  =
259-             Rc :: new ( RefCell :: new ( ProjectionPushDown :: new ( ) ) ) ; 
260- 
261-         let  optimized_plan = rule. borrow_mut ( ) . optimize ( & aggregate) . unwrap ( ) ; 
262- 
263-         let  formatted_plan = format ! ( "{:?}" ,  optimized_plan) ; 
264- 
265-         assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[]], aggr=[[#0]]\n   TableScan: test projection=Some([1])" ) ; 
292+         assert_optimized_plan_eq ( & aggregate,  "Aggregate: groupBy=[[]], aggr=[[#0]]\n   TableScan: test projection=Some([1])" ) ; 
266293    } 
267294
268295    #[ test]  
269296    fn  aggregate_group_by ( )  { 
270-         let  schema = Schema :: new ( vec ! [ 
271-             Field :: new( "a" ,  DataType :: UInt32 ,  false ) , 
272-             Field :: new( "b" ,  DataType :: UInt32 ,  false ) , 
273-             Field :: new( "c" ,  DataType :: UInt32 ,  false ) , 
274-         ] ) ; 
275- 
276-         // create unoptimized plan for SELECT MAX(b) FROM default.test 
277- 
278-         let  table_scan = TableScan  { 
279-             schema_name :  "default" . to_string ( ) , 
280-             table_name :  "test" . to_string ( ) , 
281-             schema :  Arc :: new ( schema) , 
282-             projection :  None , 
283-         } ; 
297+         let  table_scan = test_table_scan ( ) ; 
284298
285299        let  aggregate = Aggregate  { 
286300            group_expr :  vec ! [ Column ( 2 ) ] , 
@@ -292,34 +306,12 @@ mod tests {
292306            input :  Rc :: new ( table_scan) , 
293307        } ; 
294308
295-         // run optimizer rule 
296- 
297-         let  rule:  Rc < RefCell < OptimizerRule > >  =
298-             Rc :: new ( RefCell :: new ( ProjectionPushDown :: new ( ) ) ) ; 
299- 
300-         let  optimized_plan = rule. borrow_mut ( ) . optimize ( & aggregate) . unwrap ( ) ; 
301- 
302-         let  formatted_plan = format ! ( "{:?}" ,  optimized_plan) ; 
303- 
304-         assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n   TableScan: test projection=Some([1, 2])" ) ; 
309+         assert_optimized_plan_eq ( & aggregate,  "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n   TableScan: test projection=Some([1, 2])" ) ; 
305310    } 
306311
307312    #[ test]  
308313    fn  aggregate_no_group_by_with_selection ( )  { 
309-         let  schema = Schema :: new ( vec ! [ 
310-             Field :: new( "a" ,  DataType :: UInt32 ,  false ) , 
311-             Field :: new( "b" ,  DataType :: UInt32 ,  false ) , 
312-             Field :: new( "c" ,  DataType :: UInt32 ,  false ) , 
313-         ] ) ; 
314- 
315-         // create unoptimized plan for SELECT MAX(b) FROM default.test 
316- 
317-         let  table_scan = TableScan  { 
318-             schema_name :  "default" . to_string ( ) , 
319-             table_name :  "test" . to_string ( ) , 
320-             schema :  Arc :: new ( schema) , 
321-             projection :  None , 
322-         } ; 
314+         let  table_scan = test_table_scan ( ) ; 
323315
324316        let  selection = Selection  { 
325317            expr :  Column ( 2 ) , 
@@ -337,15 +329,56 @@ mod tests {
337329            input :  Rc :: new ( selection) , 
338330        } ; 
339331
340-         // run optimizer rule 
332+         assert_optimized_plan_eq ( & aggregate,  "Aggregate: groupBy=[[]], aggr=[[#0]]\n   Selection: #1\n     TableScan: test projection=Some([1, 2])" ) ; 
333+     } 
341334
342-         let  rule:  Rc < RefCell < OptimizerRule > >  =
343-             Rc :: new ( RefCell :: new ( ProjectionPushDown :: new ( ) ) ) ; 
335+     #[ test]  
336+     fn  cast ( )  { 
337+         let  table_scan = test_table_scan ( ) ; 
338+ 
339+         let  projection = Projection  { 
340+             expr :  vec ! [ Cast  { 
341+                 expr:  Rc :: new( Column ( 2 ) ) , 
342+                 data_type:  DataType :: Float64 , 
343+             } ] , 
344+             input :  Rc :: new ( table_scan) , 
345+             schema :  Arc :: new ( Schema :: new ( vec ! [ Field :: new( 
346+                 "CAST(c AS float)" , 
347+                 DataType :: Float64 , 
348+                 false , 
349+             ) ] ) ) , 
350+         } ; 
344351
345-         let  optimized_plan = rule. borrow_mut ( ) . optimize ( & aggregate) . unwrap ( ) ; 
352+         assert_optimized_plan_eq ( 
353+             & projection, 
354+             "Projection: CAST(#0 AS Float64)\n   TableScan: test projection=Some([2])" , 
355+         ) ; 
356+     } 
346357
358+     fn  assert_optimized_plan_eq ( plan :  & LogicalPlan ,  expected :  & str )  { 
359+         let  optimized_plan = optimize ( plan) ; 
347360        let  formatted_plan = format ! ( "{:?}" ,  optimized_plan) ; 
361+         assert_eq ! ( formatted_plan,  expected) ; 
362+     } 
348363
349-         assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[]], aggr=[[#0]]\n   Selection: #1\n     TableScan: test projection=Some([1, 2])" ) ; 
364+     fn  optimize ( plan :  & LogicalPlan )  -> Rc < LogicalPlan >  { 
365+         let  rule:  Rc < RefCell < OptimizerRule > >  =
366+             Rc :: new ( RefCell :: new ( ProjectionPushDown :: new ( ) ) ) ; 
367+         let  mut  borrowed_rule = rule. borrow_mut ( ) ; 
368+         borrowed_rule. optimize ( plan) . unwrap ( ) 
369+     } 
370+ 
371+     /// all tests share a common table 
372+ fn  test_table_scan ( )  -> LogicalPlan  { 
373+         TableScan  { 
374+             schema_name :  "default" . to_string ( ) , 
375+             table_name :  "test" . to_string ( ) , 
376+             schema :  Arc :: new ( Schema :: new ( vec ! [ 
377+                 Field :: new( "a" ,  DataType :: UInt32 ,  false ) , 
378+                 Field :: new( "b" ,  DataType :: UInt32 ,  false ) , 
379+                 Field :: new( "c" ,  DataType :: UInt32 ,  false ) , 
380+             ] ) ) , 
381+             projection :  None , 
382+         } 
350383    } 
351384} 
0 commit comments