Skip to content

Commit f959500

Browse files
committed
implement projection push down for rest of logical plan variants
1 parent 5fd5382 commit f959500

File tree

1 file changed

+103
-70
lines changed

1 file changed

+103
-70
lines changed

rust/datafusion/src/optimizer/projection_push_down.rs

Lines changed: 103 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ pub struct ProjectionPushDown {}
2929

3030
impl OptimizerRule for ProjectionPushDown {
3131
fn optimize(&mut self, plan: &LogicalPlan) -> Result<Rc<LogicalPlan>> {
32-
let mut accum = HashSet::new();
32+
let mut accum: HashSet<usize> = HashSet::new();
3333
let mut mapping: HashMap<usize, usize> = HashMap::new();
3434
self.optimize_plan(plan, &mut accum, &mut mapping)
3535
}
@@ -47,6 +47,29 @@ impl ProjectionPushDown {
4747
mapping: &mut HashMap<usize, usize>,
4848
) -> Result<Rc<LogicalPlan>> {
4949
match plan {
50+
LogicalPlan::Projection {
51+
expr,
52+
input,
53+
schema,
54+
} => {
55+
// collect all columns referenced by projection expressions
56+
expr.iter().for_each(|e| self.collect_expr(e, accum));
57+
58+
// push projection down
59+
let input = self.optimize_plan(&input, accum, mapping)?;
60+
61+
// rewrite projection expressions to use new column indexes
62+
let new_expr = expr
63+
.iter()
64+
.map(|e| self.rewrite_expr(e, mapping))
65+
.collect::<Result<Vec<Expr>>>()?;
66+
67+
Ok(Rc::new(LogicalPlan::Projection {
68+
expr: new_expr,
69+
input,
70+
schema: schema.clone(),
71+
}))
72+
}
5073
LogicalPlan::Selection { expr, input } => {
5174
// collect all columns referenced by filter expression
5275
self.collect_expr(expr, accum);
@@ -92,6 +115,34 @@ impl ProjectionPushDown {
92115
schema: schema.clone(),
93116
}))
94117
}
118+
LogicalPlan::Sort {
119+
expr,
120+
input,
121+
schema,
122+
} => {
123+
// collect all columns referenced by sort expressions
124+
expr.iter().for_each(|e| self.collect_expr(e, accum));
125+
126+
// push projection down
127+
let input = self.optimize_plan(&input, accum, mapping)?;
128+
129+
// rewrite sort expressions to use new column indexes
130+
let new_expr = expr
131+
.iter()
132+
.map(|e| self.rewrite_expr(e, mapping))
133+
.collect::<Result<Vec<Expr>>>()?;
134+
135+
Ok(Rc::new(LogicalPlan::Sort {
136+
expr: new_expr,
137+
input,
138+
schema: schema.clone(),
139+
}))
140+
}
141+
LogicalPlan::EmptyRelation { schema } => {
142+
Ok(Rc::new(LogicalPlan::EmptyRelation {
143+
schema: schema.clone(),
144+
}))
145+
}
95146
LogicalPlan::TableScan {
96147
schema_name,
97148
table_name,
@@ -128,8 +179,6 @@ impl ProjectionPushDown {
128179
projection: Some(projection),
129180
}))
130181
}
131-
//TODO implement all logical plan variants and remove this unimplemented
132-
_ => Err(ArrowError::ComputeError("unimplemented".to_string())),
133182
}
134183
}
135184

@@ -227,20 +276,7 @@ mod tests {
227276

228277
#[test]
229278
fn aggregate_no_group_by() {
230-
let schema = Schema::new(vec![
231-
Field::new("a", DataType::UInt32, false),
232-
Field::new("b", DataType::UInt32, false),
233-
Field::new("c", DataType::UInt32, false),
234-
]);
235-
236-
// create unoptimized plan for SELECT MAX(b) FROM default.test
237-
238-
let table_scan = TableScan {
239-
schema_name: "default".to_string(),
240-
table_name: "test".to_string(),
241-
schema: Arc::new(schema),
242-
projection: None,
243-
};
279+
let table_scan = test_table_scan();
244280

245281
let aggregate = Aggregate {
246282
group_expr: vec![],
@@ -253,34 +289,12 @@ mod tests {
253289
input: Rc::new(table_scan),
254290
};
255291

256-
// run optimizer rule
257-
258-
let rule: Rc<RefCell<OptimizerRule>> =
259-
Rc::new(RefCell::new(ProjectionPushDown::new()));
260-
261-
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
262-
263-
let formatted_plan = format!("{:?}", optimized_plan);
264-
265-
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#0]]\n TableScan: test projection=Some([1])");
292+
assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[]], aggr=[[#0]]\n TableScan: test projection=Some([1])");
266293
}
267294

268295
#[test]
269296
fn aggregate_group_by() {
270-
let schema = Schema::new(vec![
271-
Field::new("a", DataType::UInt32, false),
272-
Field::new("b", DataType::UInt32, false),
273-
Field::new("c", DataType::UInt32, false),
274-
]);
275-
276-
// create unoptimized plan for SELECT MAX(b) FROM default.test
277-
278-
let table_scan = TableScan {
279-
schema_name: "default".to_string(),
280-
table_name: "test".to_string(),
281-
schema: Arc::new(schema),
282-
projection: None,
283-
};
297+
let table_scan = test_table_scan();
284298

285299
let aggregate = Aggregate {
286300
group_expr: vec![Column(2)],
@@ -292,34 +306,12 @@ mod tests {
292306
input: Rc::new(table_scan),
293307
};
294308

295-
// run optimizer rule
296-
297-
let rule: Rc<RefCell<OptimizerRule>> =
298-
Rc::new(RefCell::new(ProjectionPushDown::new()));
299-
300-
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
301-
302-
let formatted_plan = format!("{:?}", optimized_plan);
303-
304-
assert_eq!(formatted_plan, "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n TableScan: test projection=Some([1, 2])");
309+
assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n TableScan: test projection=Some([1, 2])");
305310
}
306311

307312
#[test]
308313
fn aggregate_no_group_by_with_selection() {
309-
let schema = Schema::new(vec![
310-
Field::new("a", DataType::UInt32, false),
311-
Field::new("b", DataType::UInt32, false),
312-
Field::new("c", DataType::UInt32, false),
313-
]);
314-
315-
// create unoptimized plan for SELECT MAX(b) FROM default.test
316-
317-
let table_scan = TableScan {
318-
schema_name: "default".to_string(),
319-
table_name: "test".to_string(),
320-
schema: Arc::new(schema),
321-
projection: None,
322-
};
314+
let table_scan = test_table_scan();
323315

324316
let selection = Selection {
325317
expr: Column(2),
@@ -337,15 +329,56 @@ mod tests {
337329
input: Rc::new(selection),
338330
};
339331

340-
// run optimizer rule
332+
assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[]], aggr=[[#0]]\n Selection: #1\n TableScan: test projection=Some([1, 2])");
333+
}
341334

342-
let rule: Rc<RefCell<OptimizerRule>> =
343-
Rc::new(RefCell::new(ProjectionPushDown::new()));
335+
#[test]
336+
fn cast() {
337+
let table_scan = test_table_scan();
338+
339+
let projection = Projection {
340+
expr: vec![Cast {
341+
expr: Rc::new(Column(2)),
342+
data_type: DataType::Float64,
343+
}],
344+
input: Rc::new(table_scan),
345+
schema: Arc::new(Schema::new(vec![Field::new(
346+
"CAST(c AS float)",
347+
DataType::Float64,
348+
false,
349+
)])),
350+
};
344351

345-
let optimized_plan = rule.borrow_mut().optimize(&aggregate).unwrap();
352+
assert_optimized_plan_eq(
353+
&projection,
354+
"Projection: CAST(#0 AS Float64)\n TableScan: test projection=Some([2])",
355+
);
356+
}
346357

358+
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
359+
let optimized_plan = optimize(plan);
347360
let formatted_plan = format!("{:?}", optimized_plan);
361+
assert_eq!(formatted_plan, expected);
362+
}
348363

349-
assert_eq!(formatted_plan, "Aggregate: groupBy=[[]], aggr=[[#0]]\n Selection: #1\n TableScan: test projection=Some([1, 2])");
364+
fn optimize(plan: &LogicalPlan) -> Rc<LogicalPlan> {
365+
let rule: Rc<RefCell<OptimizerRule>> =
366+
Rc::new(RefCell::new(ProjectionPushDown::new()));
367+
let mut borrowed_rule = rule.borrow_mut();
368+
borrowed_rule.optimize(plan).unwrap()
369+
}
370+
371+
/// all tests share a common table
372+
fn test_table_scan() -> LogicalPlan {
373+
TableScan {
374+
schema_name: "default".to_string(),
375+
table_name: "test".to_string(),
376+
schema: Arc::new(Schema::new(vec![
377+
Field::new("a", DataType::UInt32, false),
378+
Field::new("b", DataType::UInt32, false),
379+
Field::new("c", DataType::UInt32, false),
380+
])),
381+
projection: None,
382+
}
350383
}
351384
}

0 commit comments

Comments
 (0)