Skip to content

Commit 92918dd

Browse files
committed
Implement projection push-down for selection and make projection deterministic
1 parent a80cfdf commit 92918dd

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

rust/datafusion/src/optimizer/projection_push_down.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ impl ProjectionPushDown {
4747
mapping: &mut HashMap<usize, usize>,
4848
) -> Result<Rc<LogicalPlan>> {
4949
match plan {
50+
LogicalPlan::Selection { expr, input } => {
51+
// collect all columns referenced by filter expression
52+
self.collect_expr(expr, accum);
53+
54+
// push projection down
55+
let input = self.optimize_plan(&input, accum, mapping)?;
56+
57+
// rewrite filter expression to use new column indexes
58+
let new_expr = self.rewrite_expr(expr, mapping);
59+
60+
Ok(Rc::new(LogicalPlan::Selection {
61+
expr: new_expr,
62+
input,
63+
}))
64+
}
5065
LogicalPlan::Aggregate {
5166
input,
5267
group_expr,
@@ -88,6 +103,9 @@ impl ProjectionPushDown {
88103
let mut projection: Vec<usize> = Vec::with_capacity(accum.len());
89104
accum.iter().for_each(|i| projection.push(*i));
90105

106+
// sort the projection, otherwise we get non-deterministic behavior
107+
projection.sort();
108+
91109
// now that the table scan is returning a different schema we need to create a
92110
// mapping from the original column index to the new column index so that we
93111
// can rewrite expressions as we walk back up the tree
@@ -184,4 +202,88 @@ mod tests {
184202

185203
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#0]]\n TableScan: test projection=Some([1])");
186204
}
205+
206+
#[test]
207+
fn aggregate_group_by() {
208+
let schema = Schema::new(vec![
209+
Field::new("a", DataType::UInt32, false),
210+
Field::new("b", DataType::UInt32, false),
211+
Field::new("c", DataType::UInt32, false),
212+
]);
213+
214+
// create unoptimized plan for SELECT c, MAX(b) FROM default.test GROUP BY c
215+
216+
let table_scan = TableScan {
217+
schema_name: "default".to_string(),
218+
table_name: "test".to_string(),
219+
schema: Arc::new(schema),
220+
projection: None,
221+
};
222+
223+
let aggregate = Aggregate {
224+
group_expr: vec![Column(2)],
225+
aggr_expr: vec![Column(1)],
226+
schema: Arc::new(Schema::new(vec![
227+
Field::new("c", DataType::UInt32, false),
228+
Field::new("MAX(b)", DataType::UInt32, false),
229+
])),
230+
input: Rc::new(table_scan),
231+
};
232+
233+
// run optimizer rule; the scan should be narrowed to columns 1 and 2 (b, c) and the aggregate expressions re-indexed to match
234+
235+
let rule: Rc<RefCell<OptimizerRule>> =
236+
Rc::new(RefCell::new(ProjectionPushDown::new()));
237+
238+
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
239+
240+
let formatted_plan = format!("{:?}", optimized_plan);
241+
242+
assert_eq!(formatted_plan, "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n TableScan: test projection=Some([1, 2])");
243+
}
244+
245+
#[test]
246+
fn aggregate_no_group_by_with_selection() {
247+
let schema = Schema::new(vec![
248+
Field::new("a", DataType::UInt32, false),
249+
Field::new("b", DataType::UInt32, false),
250+
Field::new("c", DataType::UInt32, false),
251+
]);
252+
253+
// create unoptimized plan for SELECT MAX(b) FROM default.test WHERE c
254+
255+
let table_scan = TableScan {
256+
schema_name: "default".to_string(),
257+
table_name: "test".to_string(),
258+
schema: Arc::new(schema),
259+
projection: None,
260+
};
261+
262+
let selection = Selection {
263+
expr: Column(2),
264+
input: Rc::new(table_scan),
265+
};
266+
267+
let aggregate = Aggregate {
268+
group_expr: vec![],
269+
aggr_expr: vec![Column(1)],
270+
schema: Arc::new(Schema::new(vec![Field::new(
271+
"MAX(b)",
272+
DataType::UInt32,
273+
false,
274+
)])),
275+
input: Rc::new(selection),
276+
};
277+
278+
// run optimizer rule; the scan should be narrowed to columns 1 and 2 (b, c) and both the filter and the aggregate re-indexed to match
279+
280+
let rule: Rc<RefCell<OptimizerRule>> =
281+
Rc::new(RefCell::new(ProjectionPushDown::new()));
282+
283+
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
284+
285+
let formatted_plan = format!("{:?}", optimized_plan);
286+
287+
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#0]]\n Selection: #1\n TableScan: test projection=Some([1, 2])");
288+
}
187289
}

0 commit comments

Comments
 (0)