2020use  crate :: logicalplan:: Expr ; 
2121use  crate :: logicalplan:: LogicalPlan ; 
2222use  crate :: optimizer:: optimizer:: OptimizerRule ; 
23- use  std:: collections:: HashSet ; 
23+ use  arrow:: error:: { ArrowError ,  Result } ; 
24+ use  std:: collections:: { HashMap ,  HashSet } ; 
2425use  std:: rc:: Rc ; 
2526
2627/// Projection Push Down optimizer rule ensures that only referenced columns are loaded into memory 
2728pub  struct  ProjectionPushDown  { } 
2829
2930impl  OptimizerRule  for  ProjectionPushDown  { 
30-     fn  optimize ( & mut  self ,  plan :  & LogicalPlan )  -> Rc < LogicalPlan >  { 
31+     fn  optimize ( & mut  self ,  plan :  & LogicalPlan )  -> Result < Rc < LogicalPlan > >  { 
3132        let  mut  accum = HashSet :: new ( ) ; 
32-         self . optimize_plan ( plan,  & mut  accum) 
33+         let  mut  mapping:  HashMap < usize ,  usize >  = HashMap :: new ( ) ; 
34+         self . optimize_plan ( plan,  & mut  accum,  & mut  mapping) 
3335    } 
3436} 
3537
@@ -42,7 +44,8 @@ impl ProjectionPushDown {
4244        & self , 
4345        plan :  & LogicalPlan , 
4446        accum :  & mut  HashSet < usize > , 
45-     )  -> Rc < LogicalPlan >  { 
47+         mapping :  & mut  HashMap < usize ,  usize > , 
48+     )  -> Result < Rc < LogicalPlan > >  { 
4649        match  plan { 
4750            LogicalPlan :: Aggregate  { 
4851                input, 
@@ -54,12 +57,25 @@ impl ProjectionPushDown {
5457                group_expr. iter ( ) . for_each ( |e| self . collect_expr ( e,  accum) ) ; 
5558                aggr_expr. iter ( ) . for_each ( |e| self . collect_expr ( e,  accum) ) ; 
5659
57-                 Rc :: new ( LogicalPlan :: Aggregate  { 
58-                     input :  self . optimize_plan ( & input,  accum) , 
59-                     group_expr :  group_expr. clone ( ) , 
60-                     aggr_expr :  aggr_expr. clone ( ) , 
60+                 // push projection down 
61+                 let  input = self . optimize_plan ( & input,  accum,  mapping) ?; 
62+ 
63+                 // rewrite expressions to use new column indexes 
64+                 let  new_group_expr:  Vec < Expr >  = group_expr
65+                     . iter ( ) 
66+                     . map ( |e| self . rewrite_expr ( e,  mapping) ) 
67+                     . collect ( ) ; 
68+                 let  new_aggr_expr:  Vec < Expr >  = aggr_expr
69+                     . iter ( ) 
70+                     . map ( |e| self . rewrite_expr ( e,  mapping) ) 
71+                     . collect ( ) ; 
72+ 
73+                 Ok ( Rc :: new ( LogicalPlan :: Aggregate  { 
74+                     input, 
75+                     group_expr :  new_group_expr, 
76+                     aggr_expr :  new_aggr_expr, 
6177                    schema :  schema. clone ( ) , 
62-                 } ) 
78+                 } ) ) 
6379            } 
6480            LogicalPlan :: TableScan  { 
6581                schema_name, 
@@ -71,15 +87,31 @@ impl ProjectionPushDown {
7187                // the projection in the table scan 
7288                let  mut  projection:  Vec < usize >  = Vec :: with_capacity ( accum. len ( ) ) ; 
7389                accum. iter ( ) . for_each ( |i| projection. push ( * i) ) ; 
74-                 Rc :: new ( LogicalPlan :: TableScan  { 
90+ 
91+                 // now that the table scan is returning a different schema we need to create a 
92+                 // mapping from the original column index to the new column index so that we 
93+                 // can rewrite expressions as we walk back up the tree 
94+ 
95+                 if  mapping. len ( )  != 0  { 
96+                     return  Err ( ArrowError :: ComputeError ( "illegal state" . to_string ( ) ) ) ; 
97+                 } 
98+ 
99+                 for  i in  0 ..schema. fields ( ) . len ( )  { 
100+                     if  let  Some ( n)  = projection. iter ( ) . position ( |v| * v == i)  { 
101+                         mapping. insert ( i,  n) ; 
102+                     } 
103+                 } 
104+ 
105+                 // return the table scan with projection 
106+                 Ok ( Rc :: new ( LogicalPlan :: TableScan  { 
75107                    schema_name :  schema_name. to_string ( ) , 
76108                    table_name :  table_name. to_string ( ) , 
77109                    schema :  schema. clone ( ) , 
78110                    projection :  Some ( projection) , 
79-                 } ) 
111+                 } ) ) 
80112            } 
81113            //TODO implement all logical plan variants and remove this unimplemented 
82-             _ => unimplemented ! ( ) , 
114+             _ => Err ( ArrowError :: ComputeError ( " unimplemented" . to_string ( ) ) ) , 
83115        } 
84116    } 
85117
@@ -92,6 +124,14 @@ impl ProjectionPushDown {
92124            _ => unimplemented ! ( ) , 
93125        } 
94126    } 
127+ 
128+     fn  rewrite_expr ( & self ,  expr :  & Expr ,  mapping :  & HashMap < usize ,  usize > )  -> Expr  { 
129+         match  expr { 
130+             Expr :: Column ( i)  => Expr :: Column ( * mapping. get ( i) . unwrap ( ) ) ,  //TODO error handling 
131+             //TODO implement all expression variants and remove this unimplemented 
132+             _ => unimplemented ! ( ) , 
133+         } 
134+     } 
95135} 
96136
97137#[ cfg( test) ]  
@@ -138,10 +178,10 @@ mod tests {
138178        let  rule:  Rc < RefCell < OptimizerRule > >  =
139179            Rc :: new ( RefCell :: new ( ProjectionPushDown :: new ( ) ) ) ; 
140180
141-         let  optimized_plan = rule. borrow_mut ( ) . optimize ( & aggregate) ; 
181+         let  optimized_plan = rule. borrow_mut ( ) . optimize ( & aggregate) . unwrap ( ) ; 
142182
143183        let  formatted_plan = format ! ( "{:?}" ,  optimized_plan) ; 
144184
145-         assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[]], aggr=[[#1 ]]\n   TableScan: test projection=Some([1])" ) ; 
185+         assert_eq ! ( formatted_plan,  "Aggregate: groupBy=[[]], aggr=[[#0 ]]\n   TableScan: test projection=Some([1])" ) ; 
146186    } 
147187} 
0 commit comments