Skip to content

Commit

Permalink
feat(optimizer): support column pruning for some operators (#653)
Browse files Browse the repository at this point in the history
  • Loading branch information
lokax authored Jul 29, 2022
1 parent ddf9971 commit c859dff
Show file tree
Hide file tree
Showing 17 changed files with 821 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl PlanRewriter for BoolExprSimplificationRule {
let child = self.rewrite(plan.child());
let new_plan = Arc::new(plan.clone_with_rewrite_expr(child, self));
match &new_plan.expr() {
Constant(Bool(false) | Null) => Arc::new(Dummy {}),
Constant(Bool(false) | Null) => Arc::new(Dummy::new(new_plan.schema())),
Constant(Bool(true)) => return plan.child().clone(),
_ => new_plan,
}
Expand Down
17 changes: 15 additions & 2 deletions src/optimizer/plan_nodes/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@ use super::*;

/// A dummy plan.
#[derive(Debug, Clone, Serialize)]
pub struct Dummy {}
pub struct Dummy {
schema: Vec<ColumnDesc>,
}

impl Dummy {
pub fn new(schema: Vec<ColumnDesc>) -> Self {
Self { schema }
}
}

impl PlanTreeNodeLeaf for Dummy {}
impl_plan_tree_node_for_leaf!(Dummy);
impl PlanNode for Dummy {}
impl PlanNode for Dummy {
fn schema(&self) -> Vec<ColumnDesc> {
self.schema.clone()
}
}
impl fmt::Display for Dummy {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Dummy:")
Expand Down
14 changes: 14 additions & 0 deletions src/optimizer/plan_nodes/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ impl PlanNode for Internal {
fn schema(&self) -> Vec<ColumnDesc> {
self.column_descs.clone()
}

fn prune_col(&self, required_cols: BitSet) -> PlanRef {
let (column_ids, column_descs) = required_cols
.iter()
.map(|col_idx| (self.column_ids[col_idx], self.column_descs[col_idx].clone()))
.unzip();
Internal::new(
self.table_name.clone(),
self.table_ref_id,
column_ids,
column_descs,
)
.into_plan_ref()
}
}

impl fmt::Display for Internal {
Expand Down
158 changes: 156 additions & 2 deletions src/optimizer/plan_nodes/logical_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::fmt;
use serde::Serialize;

use super::*;
use crate::binder::{BoundAggCall, BoundExpr};
use crate::binder::{BoundAggCall, BoundExpr, ExprVisitor};
use crate::optimizer::logical_plan_rewriter::ExprRewriter;

/// The logical plan of hash aggregate operation.
Expand Down Expand Up @@ -90,6 +90,76 @@ impl PlanNode for LogicalAggregate {
fn estimated_cardinality(&self) -> usize {
self.child().estimated_cardinality()
}

fn prune_col(&self, required_cols: BitSet) -> PlanRef {
let group_keys_len = self.group_keys.len();

// Collect ref_idx of AggCall args
let mut visitor =
CollectRequiredCols(BitSet::with_capacity(group_keys_len + self.agg_calls.len()));
let mut new_agg_calls: Vec<_> = required_cols
.iter()
.filter(|&index| index >= group_keys_len)
.map(|index| {
let call = &self.agg_calls[index - group_keys_len];
call.args.iter().for_each(|expr| {
visitor.visit_expr(expr);
});
self.agg_calls[index - group_keys_len].clone()
})
.collect();

// Collect ref_idx of GroupExpr
self.group_keys
.iter()
.for_each(|group| visitor.visit_expr(group));

let input_cols = visitor.0;

let mapper = Mapper::new_with_bitset(&input_cols);
for call in &mut new_agg_calls {
call.args.iter_mut().for_each(|expr| {
mapper.rewrite_expr(expr);
})
}

let mut group_keys = self.group_keys.clone();
group_keys
.iter_mut()
.for_each(|expr| mapper.rewrite_expr(expr));

let new_agg = LogicalAggregate::new(
new_agg_calls.clone(),
group_keys,
self.child.prune_col(input_cols),
);

let bitset = BitSet::from_iter(0..group_keys_len);

if bitset.is_subset(&required_cols) {
new_agg.into_plan_ref()
} else {
// Need prune
let mut new_projection: Vec<BoundExpr> = required_cols
.iter()
.filter(|&i| i < group_keys_len)
.map(|index| {
BoundExpr::InputRef(BoundInputRef {
index,
return_type: self.group_keys[index].return_type().unwrap(),
})
})
.collect();

for (index, item) in new_agg_calls.iter().enumerate() {
new_projection.push(BoundExpr::InputRef(BoundInputRef {
index: group_keys_len + index,
return_type: item.return_type.clone(),
}))
}
LogicalProjection::new(new_projection, new_agg.into_plan_ref()).into_plan_ref()
}
}
}
impl fmt::Display for LogicalAggregate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -129,7 +199,7 @@ mod tests {
},
],
vec![],
Arc::new(Dummy {}),
Arc::new(Dummy::new(Vec::new())),
);

let column_names = plan.out_names();
Expand All @@ -138,4 +208,88 @@ mod tests {
assert_eq!(column_names[2], "count");
assert_eq!(column_names[3], "count");
}

#[test]
/// Pruning
/// ```text
/// Agg(gk: input_ref(2), call: sum(input_ref(0)), avg(input_ref(1)))
/// TableScan(v1, v2, v3)
/// ```
/// with required columns [2] will result in
/// ```text
/// Projection(input_ref(1))
/// Agg(gk: input_ref(1), call: avg(input_ref(0)))
/// TableScan(v1, v3)
/// ```
fn test_prune_aggregate() {
let ty = DataTypeKind::Int(None).not_null();
let col_descs = vec![
ty.clone().to_column("v1".into()),
ty.clone().to_column("v2".into()),
ty.clone().to_column("v3".into()),
];

let table_scan = LogicalTableScan::new(
crate::catalog::TableRefId {
database_id: 0,
schema_id: 0,
table_id: 0,
},
vec![1, 2, 3],
col_descs,
false,
false,
None,
);

let input_refs = vec![
BoundExpr::InputRef(BoundInputRef {
index: 0,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 1,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 2,
return_type: ty,
}),
];

let aggregate = LogicalAggregate::new(
vec![
BoundAggCall {
kind: AggKind::Sum,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int(None).not_null(),
},
BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[1].clone()],
return_type: DataTypeKind::Int(None).not_null(),
},
],
vec![input_refs[2].clone()],
Arc::new(table_scan),
);

let mut required_cols = BitSet::new();
required_cols.insert(2);
let plan = aggregate.prune_col(required_cols);
let plan = plan.as_logical_projection().unwrap();

assert_eq!(plan.project_expressions(), vec![input_refs[1].clone()]);
let plan = plan.child();
let plan = plan.as_logical_aggregate().unwrap();

assert_eq!(
plan.agg_calls(),
vec![BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int(None).not_null(),
}]
);
}
}
8 changes: 7 additions & 1 deletion src/optimizer/plan_nodes/logical_copy_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ impl PlanTreeNodeUnary for LogicalCopyToFile {
}
}
impl_plan_tree_node_for_unary!(LogicalCopyToFile);
impl PlanNode for LogicalCopyToFile {}
impl PlanNode for LogicalCopyToFile {
fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
let input_cols = (0..self.child().out_types().len()).into_iter().collect();
self.clone_with_child(self.child.prune_col(input_cols))
.into_plan_ref()
}
}

impl fmt::Display for LogicalCopyToFile {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/plan_nodes/logical_delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ impl PlanNode for LogicalDelete {
false,
)]
}

fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
let input_cols = (0..self.child().out_types().len()).into_iter().collect();
self.clone_with_child(self.child.prune_col(input_cols))
.into_plan_ref()
}
}

impl fmt::Display for LogicalDelete {
Expand Down
35 changes: 7 additions & 28 deletions src/optimizer/plan_nodes/logical_filter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0.

use std::collections::HashMap;
use std::fmt;

use serde::Serialize;
Expand Down Expand Up @@ -56,51 +55,31 @@ impl PlanNode for LogicalFilter {
}

fn prune_col(&self, required_cols: BitSet) -> PlanRef {
struct CollectRequiredCols(BitSet);
impl ExprVisitor for CollectRequiredCols {
fn visit_input_ref(&mut self, expr: &BoundInputRef) {
self.0.insert(expr.index);
}
}
let mut visitor = CollectRequiredCols(required_cols.clone());
visitor.visit_expr(&self.expr);
let input_cols = visitor.0;

let mut idx_table = HashMap::new();
for (new_idx, old_idx) in input_cols.iter().enumerate() {
idx_table.insert(old_idx, new_idx);
}

struct Mapper(HashMap<usize, usize>);
impl ExprRewriter for Mapper {
fn rewrite_input_ref(&self, expr: &mut BoundExpr) {
match expr {
BoundExpr::InputRef(ref mut input_ref) => {
input_ref.index = self.0[&input_ref.index];
}
_ => unreachable!(),
}
}
}

let mut expr = self.expr.clone();
Mapper(idx_table.clone()).rewrite_expr(&mut expr);
let mapper = Mapper::new_with_bitset(&input_cols);
mapper.rewrite_expr(&mut expr);

let need_prune = required_cols != input_cols;
let new_filter = Self {
expr,
child: self.child.prune_col(input_cols.clone()),
child: self.child.prune_col(input_cols),
}
.into_plan_ref();

if required_cols == input_cols {
if !need_prune {
return new_filter;
}

let input_types = self.out_types();
let exprs = required_cols
.iter()
.map(|old_idx| {
BoundExpr::InputRef(BoundInputRef {
index: idx_table[&old_idx],
index: mapper[old_idx],
return_type: input_types[old_idx].clone(),
})
})
Expand Down
10 changes: 10 additions & 0 deletions src/optimizer/plan_nodes/logical_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ impl PlanNode for LogicalInsert {
false,
)]
}

fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
let input_cols = self
.column_ids
.iter()
.map(|&column_id| column_id as usize)
.collect();
self.clone_with_child(self.child.prune_col(input_cols))
.into_plan_ref()
}
}

impl fmt::Display for LogicalInsert {
Expand Down
Loading

0 comments on commit c859dff

Please sign in to comment.