Skip to content

Commit 48ea4b2

Browse files
haohuaijinalambjonahgao
authored
fix: don't push down volatile predicates in projection (#7909)
* fix: don't push down volatile predicates in projection * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * add suggestions * fix * fix doc * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Jonah Gao <jonahgaox@gmail.com> * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Jonah Gao <jonahgaox@gmail.com> * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Jonah Gao <jonahgaox@gmail.com> * Update datafusion/optimizer/src/push_down_filter.rs Co-authored-by: Jonah Gao <jonahgaox@gmail.com> --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> Co-authored-by: Jonah Gao <jonahgaox@gmail.com>
1 parent d190ff1 commit 48ea4b2

File tree

1 file changed

+167
-27
lines changed

1 file changed

+167
-27
lines changed

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 167 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan
1616
1717
use crate::optimizer::ApplyOrder;
18-
use crate::utils::{conjunction, split_conjunction};
18+
use crate::utils::{conjunction, split_conjunction, split_conjunction_owned};
1919
use crate::{utils, OptimizerConfig, OptimizerRule};
2020
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
2121
use datafusion_common::{
2222
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
2323
};
2424
use datafusion_expr::expr::Alias;
25+
use datafusion_expr::Volatility;
2526
use datafusion_expr::{
2627
and,
2728
expr_rewriter::replace_col,
@@ -652,32 +653,60 @@ impl OptimizerRule for PushDownFilter {
652653
child_plan.with_new_inputs(&[new_filter])?
653654
}
654655
LogicalPlan::Projection(projection) => {
655-
// A projection is filter-commutable, but re-writes all predicate expressions
656+
// A projection is filter-commutable if it does not contain volatile predicates, or if it contains volatile
657+
// predicates that are not used in the filter. However, we must still rewrite all predicate expressions.
656658
// collect projection.
657-
let replace_map = projection
658-
.schema
659-
.fields()
660-
.iter()
661-
.enumerate()
662-
.map(|(i, field)| {
663-
// strip alias, as they should not be part of filters
664-
let expr = match &projection.expr[i] {
665-
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
666-
expr => expr.clone(),
667-
};
668-
669-
(field.qualified_name(), expr)
670-
})
671-
.collect::<HashMap<_, _>>();
659+
let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) =
660+
projection
661+
.schema
662+
.fields()
663+
.iter()
664+
.enumerate()
665+
.map(|(i, field)| {
666+
// strip alias, as they should not be part of filters
667+
let expr = match &projection.expr[i] {
668+
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
669+
expr => expr.clone(),
670+
};
671+
672+
(field.qualified_name(), expr)
673+
})
674+
.partition(|(_, value)| is_volatile_expression(value));
672675

673-
// re-write all filters based on this projection
674-
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
675-
let new_filter = LogicalPlan::Filter(Filter::try_new(
676-
replace_cols_by_name(filter.predicate.clone(), &replace_map)?,
677-
projection.input.clone(),
678-
)?);
676+
let mut push_predicates = vec![];
677+
let mut keep_predicates = vec![];
678+
for expr in split_conjunction_owned(filter.predicate.clone()).into_iter()
679+
{
680+
if contain(&expr, &volatile_map) {
681+
keep_predicates.push(expr);
682+
} else {
683+
push_predicates.push(expr);
684+
}
685+
}
679686

680-
child_plan.with_new_inputs(&[new_filter])?
687+
match conjunction(push_predicates) {
688+
Some(expr) => {
689+
// re-write all filters based on this projection
690+
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
691+
let new_filter = LogicalPlan::Filter(Filter::try_new(
692+
replace_cols_by_name(expr, &non_volatile_map)?,
693+
projection.input.clone(),
694+
)?);
695+
696+
match conjunction(keep_predicates) {
697+
None => child_plan.with_new_inputs(&[new_filter])?,
698+
Some(keep_predicate) => {
699+
let child_plan =
700+
child_plan.with_new_inputs(&[new_filter])?;
701+
LogicalPlan::Filter(Filter::try_new(
702+
keep_predicate,
703+
Arc::new(child_plan),
704+
)?)
705+
}
706+
}
707+
}
708+
None => return Ok(None),
709+
}
681710
}
682711
LogicalPlan::Union(union) => {
683712
let mut inputs = Vec::with_capacity(union.inputs.len());
@@ -881,6 +910,42 @@ pub fn replace_cols_by_name(
881910
})
882911
}
883912

913+
/// Returns `true` if `e` contains a call to a volatile scalar function.
///
/// The expression is visited with `apply`; the visit stops at the first
/// `Expr::ScalarFunction` whose `fun.volatility()` equals `Volatility::Volatile`.
/// Only the `Expr::ScalarFunction` variant is inspected here; every other node
/// is passed over with `VisitRecursion::Continue`.
914+
fn is_volatile_expression(e: &Expr) -> bool {
915+
// Flipped by the visitor closure below as soon as a volatile function is seen.
let mut is_volatile = false;
916+
e.apply(&mut |expr| {
917+
Ok(match expr {
918+
// A volatile scalar function call: record it and stop visiting.
Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => {
919+
is_volatile = true;
920+
VisitRecursion::Stop
921+
}
922+
_ => VisitRecursion::Continue,
923+
})
924+
})
925+
// The closure above only ever returns `Ok`; presumably `apply` therefore cannot fail here — TODO confirm.
.unwrap();
926+
is_volatile
927+
}
928+
929+
/// Returns `true` if `e` references a column whose flat name is a key of `check_map`.
///
/// The expression is visited with `apply`; for each `Expr::Column` node the column's
/// `flat_name()` is looked up in `check_map`, and the visit stops at the first hit.
/// Only the keys of `check_map` matter here; the mapped expressions are not read.
930+
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
931+
// Flipped by the visitor closure below on the first matching column.
let mut is_contain = false;
932+
e.apply(&mut |expr| {
933+
Ok(if let Expr::Column(c) = &expr {
934+
match check_map.get(&c.flat_name()) {
935+
// The column is present in the map: record it and stop visiting.
Some(_) => {
936+
is_contain = true;
937+
VisitRecursion::Stop
938+
}
939+
None => VisitRecursion::Continue,
940+
}
941+
} else {
942+
// Non-column nodes cannot match by themselves; keep visiting.
VisitRecursion::Continue
943+
})
944+
})
945+
// The closure above only ever returns `Ok`; presumably `apply` therefore cannot fail here — TODO confirm.
.unwrap();
946+
is_contain
947+
}
948+
884949
#[cfg(test)]
885950
mod tests {
886951
use super::*;
@@ -893,9 +958,9 @@ mod tests {
893958
use datafusion_common::{DFSchema, DFSchemaRef};
894959
use datafusion_expr::logical_plan::table_scan;
895960
use datafusion_expr::{
896-
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
897-
Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
898-
UserDefinedLogicalNodeCore,
961+
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum,
962+
BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource,
963+
TableType, UserDefinedLogicalNodeCore,
899964
};
900965
use std::fmt::{Debug, Formatter};
901966
use std::sync::Arc;
@@ -2712,4 +2777,79 @@ Projection: a, b
27122777
\n TableScan: test2";
27132778
assert_optimized_plan_eq(&plan, expected)
27142779
}
2780+
2781+
// Filter push-down through a projection that contains a volatile expression
// (`random() + 1 AS r`) on top of an aggregate. Per the expected plan below, the
// predicate on `t.a` ends up in the table scan's `full_filters`, while the predicate
// on `r` stays in a `Filter` above the projection that computes `r`.
#[test]
2782+
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
2783+
// SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
2784+
let table_scan = test_table_scan_with_name("test1")?;
2785+
let plan = LogicalPlanBuilder::from(table_scan)
2786+
.aggregate(vec![col("a")], vec![sum(col("b"))])?
2787+
.project(vec![
2788+
col("a"),
2789+
sum(col("b")),
2790+
add(random(), lit(1)).alias("r"),
2791+
])?
2792+
.alias("t")?
2793+
.filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
2794+
.project(vec![col("t.a"), col("t.r")])?
2795+
.build()?;
2796+
2797+
// Sanity-check the plan produced by the builder before optimizing it.
let expected_before = "Projection: t.a, t.r\
2798+
\n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\
2799+
\n SubqueryAlias: t\
2800+
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
2801+
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
2802+
\n TableScan: test1";
2803+
assert_eq!(format!("{plan:?}"), expected_before);
2804+
2805+
// After optimization only the non-volatile conjunct has moved below the projection.
let expected_after = "Projection: t.a, t.r\
2806+
\n SubqueryAlias: t\
2807+
\n Filter: r > Float64(0.5)\
2808+
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
2809+
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
2810+
\n TableScan: test1, full_filters=[test1.a > Int32(5)]";
2811+
assert_optimized_plan_eq(&plan, expected_after)
2812+
}
2813+
2814+
// Filter push-down through a projection with a volatile column (`random() AS r`)
// on top of an inner join. The only predicate, `t.r > 0.8`, uses the volatile column,
// so per the expected plan below it stays in a `Filter` above the projection and
// nothing is pushed into the join or its inputs.
#[test]
2815+
fn test_push_down_volatile_function_in_join() -> Result<()> {
2816+
// SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.8;
2817+
let table_scan = test_table_scan_with_name("test1")?;
2818+
let left = LogicalPlanBuilder::from(table_scan).build()?;
2819+
let right_table_scan = test_table_scan_with_name("test2")?;
2820+
let right = LogicalPlanBuilder::from(right_table_scan).build()?;
2821+
let plan = LogicalPlanBuilder::from(left)
2822+
.join(
2823+
right,
2824+
JoinType::Inner,
2825+
(
2826+
vec![Column::from_qualified_name("test1.a")],
2827+
vec![Column::from_qualified_name("test2.a")],
2828+
),
2829+
None,
2830+
)?
2831+
.project(vec![col("test1.a").alias("a"), random().alias("r")])?
2832+
.alias("t")?
2833+
.filter(col("t.r").gt(lit(0.8)))?
2834+
.project(vec![col("t.a"), col("t.r")])?
2835+
.build()?;
2836+
2837+
// Sanity-check the plan produced by the builder before optimizing it.
let expected_before = "Projection: t.a, t.r\
2838+
\n Filter: t.r > Float64(0.8)\
2839+
\n SubqueryAlias: t\
2840+
\n Projection: test1.a AS a, random() AS r\
2841+
\n Inner Join: test1.a = test2.a\
2842+
\n TableScan: test1\
2843+
\n TableScan: test2";
2844+
assert_eq!(format!("{plan:?}"), expected_before);
2845+
2846+
// After optimization the filter sits below the alias but still above the projection.
let expected = "Projection: t.a, t.r\
2847+
\n SubqueryAlias: t\
2848+
\n Filter: r > Float64(0.8)\
2849+
\n Projection: test1.a AS a, random() AS r\
2850+
\n Inner Join: test1.a = test2.a\
2851+
\n TableScan: test1\
2852+
\n TableScan: test2";
2853+
assert_optimized_plan_eq(&plan, expected)
2854+
}
27152855
}

0 commit comments

Comments
 (0)