Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): support column pruning for some operators #653

Merged
merged 8 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl PlanRewriter for BoolExprSimplificationRule {
let child = self.rewrite(plan.child());
let new_plan = Arc::new(plan.clone_with_rewrite_expr(child, self));
match &new_plan.expr() {
Constant(Bool(false) | Null) => Arc::new(Dummy {}),
Constant(Bool(false) | Null) => Arc::new(Dummy::new(new_plan.schema())),
Constant(Bool(true)) => return plan.child().clone(),
_ => new_plan,
}
Expand Down
17 changes: 15 additions & 2 deletions src/optimizer/plan_nodes/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@ use super::*;

/// A dummy plan.
#[derive(Debug, Clone, Serialize)]
pub struct Dummy {}
pub struct Dummy {
schema: Vec<ColumnDesc>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does Dummy need a schema?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was likely a pre-existing bug, and I fixed it along the way.
In SQL: SELECT x FROM test WHERE 1 > 3.
the projection operator's schema() requires the schema of its child operator, so schema() would panic if Dummy did not carry a schema.

}

impl Dummy {
    /// Creates a dummy plan node that reports `schema` as its output schema.
    pub fn new(schema: Vec<ColumnDesc>) -> Self {
        Self { schema }
    }
}

impl PlanTreeNodeLeaf for Dummy {}
impl_plan_tree_node_for_leaf!(Dummy);
impl PlanNode for Dummy {}
impl PlanNode for Dummy {
    /// Returns the schema supplied at construction time. Parent operators
    /// (e.g. a projection over `SELECT x FROM t WHERE 1 > 3`) query the
    /// child's schema, so a schemaless dummy would cause a panic there.
    fn schema(&self) -> Vec<ColumnDesc> {
        self.schema.clone()
    }
}
impl fmt::Display for Dummy {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "Dummy:")
Expand Down
14 changes: 14 additions & 0 deletions src/optimizer/plan_nodes/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ impl PlanNode for Internal {
fn schema(&self) -> Vec<ColumnDesc> {
self.column_descs.clone()
}

fn prune_col(&self, required_cols: BitSet) -> PlanRef {
let (column_ids, column_descs) = required_cols
.iter()
.map(|col_idx| (self.column_ids[col_idx], self.column_descs[col_idx].clone()))
.unzip();
Internal::new(
self.table_name.clone(),
self.table_ref_id,
column_ids,
column_descs,
)
.into_plan_ref()
}
}

impl fmt::Display for Internal {
Expand Down
158 changes: 156 additions & 2 deletions src/optimizer/plan_nodes/logical_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::fmt;
use serde::Serialize;

use super::*;
use crate::binder::{BoundAggCall, BoundExpr};
use crate::binder::{BoundAggCall, BoundExpr, ExprVisitor};
use crate::optimizer::logical_plan_rewriter::ExprRewriter;

/// The logical plan of hash aggregate operation.
Expand Down Expand Up @@ -90,6 +90,76 @@ impl PlanNode for LogicalAggregate {
fn estimated_cardinality(&self) -> usize {
self.child().estimated_cardinality()
}

    fn prune_col(&self, required_cols: BitSet) -> PlanRef {
        // Output layout of an aggregate is [group keys..., agg calls...]:
        // an output index below `group_keys_len` is a group key, anything at
        // or above it is an agg call.
        let group_keys_len = self.group_keys.len();

        // Collect ref_idx of AggCall args
        let mut visitor =
            CollectRequiredCols(BitSet::with_capacity(group_keys_len + self.agg_calls.len()));
        // Keep only the agg calls whose output index is required, recording
        // (as a side effect of the map) which child columns their arguments
        // reference.
        let mut new_agg_calls: Vec<_> = required_cols
            .iter()
            .filter(|&index| index >= group_keys_len)
            .map(|index| {
                let call = &self.agg_calls[index - group_keys_len];
                call.args.iter().for_each(|expr| {
                    visitor.visit_expr(expr);
                });
                self.agg_calls[index - group_keys_len].clone()
            })
            .collect();

        // Collect ref_idx of GroupExpr
        // Every group key is kept — even ones not in `required_cols` — since
        // the aggregate still needs them to group by, so their column refs
        // are required from the child as well.
        self.group_keys
            .iter()
            .for_each(|group| visitor.visit_expr(group));

        // The set of columns the pruned child must still produce.
        let input_cols = visitor.0;

        // Remap argument/group-key indices from the old child layout to the
        // pruned child layout.
        let mapper = Mapper::new_with_bitset(&input_cols);
        for call in &mut new_agg_calls {
            call.args.iter_mut().for_each(|expr| {
                mapper.rewrite_expr(expr);
            })
        }

        let mut group_keys = self.group_keys.clone();
        group_keys
            .iter_mut()
            .for_each(|expr| mapper.rewrite_expr(expr));

        let new_agg = LogicalAggregate::new(
            new_agg_calls.clone(),
            group_keys,
            self.child.prune_col(input_cols),
        );

        let bitset = BitSet::from_iter(0..group_keys_len);

        // If every group key is required, the new aggregate's output already
        // matches `required_cols` in order; otherwise add a projection to
        // drop the unrequired group keys.
        if bitset.is_subset(&required_cols) {
            new_agg.into_plan_ref()
        } else {
            // Need prune
            // A required group key keeps its original position in the new
            // aggregate's output, because no group key was dropped above.
            let mut new_projection: Vec<BoundExpr> = required_cols
                .iter()
                .filter(|&i| i < group_keys_len)
                .map(|index| {
                    BoundExpr::InputRef(BoundInputRef {
                        index,
                        return_type: self.group_keys[index].return_type().unwrap(),
                    })
                })
                .collect();

            // The surviving agg calls follow all group keys in the new
            // aggregate's output, in the order they were collected above.
            for (index, item) in new_agg_calls.iter().enumerate() {
                new_projection.push(BoundExpr::InputRef(BoundInputRef {
                    index: group_keys_len + index,
                    return_type: item.return_type.clone(),
                }))
            }
            LogicalProjection::new(new_projection, new_agg.into_plan_ref()).into_plan_ref()
        }
    }
}
impl fmt::Display for LogicalAggregate {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -129,7 +199,7 @@ mod tests {
},
],
vec![],
Arc::new(Dummy {}),
Arc::new(Dummy::new(Vec::new())),
);

let column_names = plan.out_names();
Expand All @@ -138,4 +208,88 @@ mod tests {
assert_eq!(column_names[2], "count");
assert_eq!(column_names[3], "count");
}

#[test]
/// Pruning
/// ```text
/// Agg(gk: input_ref(2), call: sum(input_ref(0)), avg(input_ref(1)))
/// TableScan(v1, v2, v3)
/// ```
/// with required columns [2] will result in
/// ```text
/// Projection(input_ref(1))
/// Agg(gk: input_ref(1), call: avg(input_ref(0)))
/// TableScan(v1, v3)
/// ```
fn test_prune_aggregate() {
let ty = DataTypeKind::Int(None).not_null();
let col_descs = vec![
ty.clone().to_column("v1".into()),
ty.clone().to_column("v2".into()),
ty.clone().to_column("v3".into()),
];

let table_scan = LogicalTableScan::new(
crate::catalog::TableRefId {
database_id: 0,
schema_id: 0,
table_id: 0,
},
vec![1, 2, 3],
col_descs,
false,
false,
None,
);

let input_refs = vec![
BoundExpr::InputRef(BoundInputRef {
index: 0,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 1,
return_type: ty.clone(),
}),
BoundExpr::InputRef(BoundInputRef {
index: 2,
return_type: ty,
}),
];

let aggregate = LogicalAggregate::new(
vec![
BoundAggCall {
kind: AggKind::Sum,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int(None).not_null(),
},
BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[1].clone()],
return_type: DataTypeKind::Int(None).not_null(),
},
],
vec![input_refs[2].clone()],
Arc::new(table_scan),
);

let mut required_cols = BitSet::new();
required_cols.insert(2);
let plan = aggregate.prune_col(required_cols);
let plan = plan.as_logical_projection().unwrap();

assert_eq!(plan.project_expressions(), vec![input_refs[1].clone()]);
let plan = plan.child();
let plan = plan.as_logical_aggregate().unwrap();

assert_eq!(
plan.agg_calls(),
vec![BoundAggCall {
kind: AggKind::Avg,
args: vec![input_refs[0].clone()],
return_type: DataTypeKind::Int(None).not_null(),
}]
);
}
}
8 changes: 7 additions & 1 deletion src/optimizer/plan_nodes/logical_copy_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ impl PlanTreeNodeUnary for LogicalCopyToFile {
}
}
impl_plan_tree_node_for_unary!(LogicalCopyToFile);
impl PlanNode for LogicalCopyToFile {}
impl PlanNode for LogicalCopyToFile {
    /// `COPY TO` writes every column of its child, so nothing can be pruned
    /// at this node; require all of the child's output columns and recurse
    /// the pruning into the child.
    fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
        // A `Range` is already an iterator; `.into_iter()` was redundant
        // (clippy::useless_conversion).
        let input_cols = (0..self.child().out_types().len()).collect();
        self.clone_with_child(self.child.prune_col(input_cols))
            .into_plan_ref()
    }
}

impl fmt::Display for LogicalCopyToFile {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/plan_nodes/logical_delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ impl PlanNode for LogicalDelete {
false,
)]
}

fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
let input_cols = (0..self.child().out_types().len()).into_iter().collect();
self.clone_with_child(self.child.prune_col(input_cols))
.into_plan_ref()
}
}

impl fmt::Display for LogicalDelete {
Expand Down
35 changes: 7 additions & 28 deletions src/optimizer/plan_nodes/logical_filter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright 2022 RisingLight Project Authors. Licensed under Apache-2.0.

use std::collections::HashMap;
use std::fmt;

use serde::Serialize;
Expand Down Expand Up @@ -56,51 +55,31 @@ impl PlanNode for LogicalFilter {
}

fn prune_col(&self, required_cols: BitSet) -> PlanRef {
struct CollectRequiredCols(BitSet);
impl ExprVisitor for CollectRequiredCols {
fn visit_input_ref(&mut self, expr: &BoundInputRef) {
self.0.insert(expr.index);
}
}
let mut visitor = CollectRequiredCols(required_cols.clone());
visitor.visit_expr(&self.expr);
let input_cols = visitor.0;

let mut idx_table = HashMap::new();
for (new_idx, old_idx) in input_cols.iter().enumerate() {
idx_table.insert(old_idx, new_idx);
}

struct Mapper(HashMap<usize, usize>);
impl ExprRewriter for Mapper {
fn rewrite_input_ref(&self, expr: &mut BoundExpr) {
match expr {
BoundExpr::InputRef(ref mut input_ref) => {
input_ref.index = self.0[&input_ref.index];
}
_ => unreachable!(),
}
}
}

let mut expr = self.expr.clone();
Mapper(idx_table.clone()).rewrite_expr(&mut expr);
let mapper = Mapper::new_with_bitset(&input_cols);
mapper.rewrite_expr(&mut expr);

let need_prune = required_cols != input_cols;
let new_filter = Self {
expr,
child: self.child.prune_col(input_cols.clone()),
child: self.child.prune_col(input_cols),
}
.into_plan_ref();

if required_cols == input_cols {
if !need_prune {
return new_filter;
}

let input_types = self.out_types();
let exprs = required_cols
.iter()
.map(|old_idx| {
BoundExpr::InputRef(BoundInputRef {
index: idx_table[&old_idx],
index: mapper[old_idx],
return_type: input_types[old_idx].clone(),
})
})
Expand Down
10 changes: 10 additions & 0 deletions src/optimizer/plan_nodes/logical_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ impl PlanNode for LogicalInsert {
false,
)]
}

fn prune_col(&self, _required_cols: BitSet) -> PlanRef {
let input_cols = self
.column_ids
.iter()
.map(|&column_id| column_id as usize)
.collect();
self.clone_with_child(self.child.prune_col(input_cols))
.into_plan_ref()
}
}

impl fmt::Display for LogicalInsert {
Expand Down
Loading