Skip to content

Commit a80cfdf

Browse files
committed
Implement mapping and expression rewrite logic
1 parent 26fd3b4 commit a80cfdf

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

rust/datafusion/src/optimizer/optimizer.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
//! Query optimizer traits
1919
2020
use crate::logicalplan::LogicalPlan;
21+
use arrow::error::Result;
2122
use std::rc::Rc;
2223

2324
/// An optimizer rule performs a transformation on a logical plan to produce an optimized logical plan.
2425
pub trait OptimizerRule {
25-
fn optimize(&mut self, plan: &LogicalPlan) -> Rc<LogicalPlan>;
26+
fn optimize(&mut self, plan: &LogicalPlan) -> Result<Rc<LogicalPlan>>;
2627
}

rust/datafusion/src/optimizer/projection_push_down.rs

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,18 @@
2020
use crate::logicalplan::Expr;
2121
use crate::logicalplan::LogicalPlan;
2222
use crate::optimizer::optimizer::OptimizerRule;
23-
use std::collections::HashSet;
23+
use arrow::error::{ArrowError, Result};
24+
use std::collections::{HashMap, HashSet};
2425
use std::rc::Rc;
2526

2627
/// The Projection Push Down optimizer rule ensures that only referenced columns are loaded into memory.
2728
pub struct ProjectionPushDown {}
2829

2930
impl OptimizerRule for ProjectionPushDown {
30-
fn optimize(&mut self, plan: &LogicalPlan) -> Rc<LogicalPlan> {
31+
fn optimize(&mut self, plan: &LogicalPlan) -> Result<Rc<LogicalPlan>> {
3132
let mut accum = HashSet::new();
32-
self.optimize_plan(plan, &mut accum)
33+
let mut mapping: HashMap<usize, usize> = HashMap::new();
34+
self.optimize_plan(plan, &mut accum, &mut mapping)
3335
}
3436
}
3537

@@ -42,7 +44,8 @@ impl ProjectionPushDown {
4244
&self,
4345
plan: &LogicalPlan,
4446
accum: &mut HashSet<usize>,
45-
) -> Rc<LogicalPlan> {
47+
mapping: &mut HashMap<usize, usize>,
48+
) -> Result<Rc<LogicalPlan>> {
4649
match plan {
4750
LogicalPlan::Aggregate {
4851
input,
@@ -54,12 +57,25 @@ impl ProjectionPushDown {
5457
group_expr.iter().for_each(|e| self.collect_expr(e, accum));
5558
aggr_expr.iter().for_each(|e| self.collect_expr(e, accum));
5659

57-
Rc::new(LogicalPlan::Aggregate {
58-
input: self.optimize_plan(&input, accum),
59-
group_expr: group_expr.clone(),
60-
aggr_expr: aggr_expr.clone(),
60+
// push projection down
61+
let input = self.optimize_plan(&input, accum, mapping)?;
62+
63+
// rewrite expressions to use new column indexes
64+
let new_group_expr: Vec<Expr> = group_expr
65+
.iter()
66+
.map(|e| self.rewrite_expr(e, mapping))
67+
.collect();
68+
let new_aggr_expr: Vec<Expr> = aggr_expr
69+
.iter()
70+
.map(|e| self.rewrite_expr(e, mapping))
71+
.collect();
72+
73+
Ok(Rc::new(LogicalPlan::Aggregate {
74+
input,
75+
group_expr: new_group_expr,
76+
aggr_expr: new_aggr_expr,
6177
schema: schema.clone(),
62-
})
78+
}))
6379
}
6480
LogicalPlan::TableScan {
6581
schema_name,
@@ -71,15 +87,31 @@ impl ProjectionPushDown {
7187
// the projection in the table scan
7288
let mut projection: Vec<usize> = Vec::with_capacity(accum.len());
7389
accum.iter().for_each(|i| projection.push(*i));
74-
Rc::new(LogicalPlan::TableScan {
90+
91+
// now that the table scan is returning a different schema, we need to create a
92+
// mapping from the original column index to the new column index so that we
93+
// can rewrite expressions as we walk back up the tree
94+
95+
if mapping.len() != 0 {
96+
return Err(ArrowError::ComputeError("illegal state".to_string()));
97+
}
98+
99+
for i in 0..schema.fields().len() {
100+
if let Some(n) = projection.iter().position(|v| *v == i) {
101+
mapping.insert(i, n);
102+
}
103+
}
104+
105+
// return the table scan with projection
106+
Ok(Rc::new(LogicalPlan::TableScan {
75107
schema_name: schema_name.to_string(),
76108
table_name: table_name.to_string(),
77109
schema: schema.clone(),
78110
projection: Some(projection),
79-
})
111+
}))
80112
}
81113
//TODO implement all logical plan variants and remove this unimplemented
82-
_ => unimplemented!(),
114+
_ => Err(ArrowError::ComputeError("unimplemented".to_string())),
83115
}
84116
}
85117

@@ -92,6 +124,14 @@ impl ProjectionPushDown {
92124
_ => unimplemented!(),
93125
}
94126
}
127+
128+
fn rewrite_expr(&self, expr: &Expr, mapping: &HashMap<usize, usize>) -> Expr {
129+
match expr {
130+
Expr::Column(i) => Expr::Column(*mapping.get(i).unwrap()), //TODO error handling
131+
//TODO implement all expression variants and remove this unimplemented
132+
_ => unimplemented!(),
133+
}
134+
}
95135
}
96136

97137
#[cfg(test)]
@@ -138,10 +178,10 @@ mod tests {
138178
let rule: Rc<RefCell<OptimizerRule>> =
139179
Rc::new(RefCell::new(ProjectionPushDown::new()));
140180

141-
let optimized_plan = rule.borrow_mut().optimize(&aggregate);
181+
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
142182

143183
let formatted_plan = format!("{:?}", optimized_plan);
144184

145-
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#1]]\n TableScan: test projection=Some([1])");
185+
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#0]]\n TableScan: test projection=Some([1])");
146186
}
147187
}

0 commit comments

Comments
 (0)