1515//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan
1616
1717use crate :: optimizer:: ApplyOrder ;
18- use crate :: utils:: { conjunction, split_conjunction} ;
18+ use crate :: utils:: { conjunction, split_conjunction, split_conjunction_owned } ;
1919use crate :: { utils, OptimizerConfig , OptimizerRule } ;
2020use datafusion_common:: tree_node:: { Transformed , TreeNode , VisitRecursion } ;
2121use datafusion_common:: {
2222 internal_err, plan_datafusion_err, Column , DFSchema , DataFusionError , Result ,
2323} ;
2424use datafusion_expr:: expr:: Alias ;
25+ use datafusion_expr:: Volatility ;
2526use datafusion_expr:: {
2627 and,
2728 expr_rewriter:: replace_col,
@@ -652,32 +653,60 @@ impl OptimizerRule for PushDownFilter {
652653 child_plan. with_new_inputs ( & [ new_filter] ) ?
653654 }
654655 LogicalPlan :: Projection ( projection) => {
655- // A projection is filter-commutable, but re-writes all predicate expressions
656+ // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
657+ // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
656658 // collect projection.
657- let replace_map = projection
658- . schema
659- . fields ( )
660- . iter ( )
661- . enumerate ( )
662- . map ( |( i, field) | {
663- // strip alias, as they should not be part of filters
664- let expr = match & projection. expr [ i] {
665- Expr :: Alias ( Alias { expr, .. } ) => expr. as_ref ( ) . clone ( ) ,
666- expr => expr. clone ( ) ,
667- } ;
668-
669- ( field. qualified_name ( ) , expr)
670- } )
671- . collect :: < HashMap < _ , _ > > ( ) ;
659+ let ( volatile_map, non_volatile_map) : ( HashMap < _ , _ > , HashMap < _ , _ > ) =
660+ projection
661+ . schema
662+ . fields ( )
663+ . iter ( )
664+ . enumerate ( )
665+ . map ( |( i, field) | {
666+ // strip alias, as they should not be part of filters
667+ let expr = match & projection. expr [ i] {
668+ Expr :: Alias ( Alias { expr, .. } ) => expr. as_ref ( ) . clone ( ) ,
669+ expr => expr. clone ( ) ,
670+ } ;
671+
672+ ( field. qualified_name ( ) , expr)
673+ } )
674+ . partition ( |( _, value) | is_volatile_expression ( value) ) ;
672675
673- // re-write all filters based on this projection
674- // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
675- let new_filter = LogicalPlan :: Filter ( Filter :: try_new (
676- replace_cols_by_name ( filter. predicate . clone ( ) , & replace_map) ?,
677- projection. input . clone ( ) ,
678- ) ?) ;
676+ let mut push_predicates = vec ! [ ] ;
677+ let mut keep_predicates = vec ! [ ] ;
678+ for expr in split_conjunction_owned ( filter. predicate . clone ( ) ) . into_iter ( )
679+ {
680+ if contain ( & expr, & volatile_map) {
681+ keep_predicates. push ( expr) ;
682+ } else {
683+ push_predicates. push ( expr) ;
684+ }
685+ }
679686
680- child_plan. with_new_inputs ( & [ new_filter] ) ?
687+ match conjunction ( push_predicates) {
688+ Some ( expr) => {
689+ // re-write all filters based on this projection
690+ // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
691+ let new_filter = LogicalPlan :: Filter ( Filter :: try_new (
692+ replace_cols_by_name ( expr, & non_volatile_map) ?,
693+ projection. input . clone ( ) ,
694+ ) ?) ;
695+
696+ match conjunction ( keep_predicates) {
697+ None => child_plan. with_new_inputs ( & [ new_filter] ) ?,
698+ Some ( keep_predicate) => {
699+ let child_plan =
700+ child_plan. with_new_inputs ( & [ new_filter] ) ?;
701+ LogicalPlan :: Filter ( Filter :: try_new (
702+ keep_predicate,
703+ Arc :: new ( child_plan) ,
704+ ) ?)
705+ }
706+ }
707+ }
708+ None => return Ok ( None ) ,
709+ }
681710 }
682711 LogicalPlan :: Union ( union) => {
683712 let mut inputs = Vec :: with_capacity ( union. inputs . len ( ) ) ;
@@ -881,6 +910,42 @@ pub fn replace_cols_by_name(
881910 } )
882911}
883912
913+ /// check whether the expression is volatile predicates
914+ fn is_volatile_expression ( e : & Expr ) -> bool {
915+ let mut is_volatile = false ;
916+ e. apply ( & mut |expr| {
917+ Ok ( match expr {
918+ Expr :: ScalarFunction ( f) if f. fun . volatility ( ) == Volatility :: Volatile => {
919+ is_volatile = true ;
920+ VisitRecursion :: Stop
921+ }
922+ _ => VisitRecursion :: Continue ,
923+ } )
924+ } )
925+ . unwrap ( ) ;
926+ is_volatile
927+ }
928+
929+ /// check whether the expression uses the columns in `check_map`.
930+ fn contain ( e : & Expr , check_map : & HashMap < String , Expr > ) -> bool {
931+ let mut is_contain = false ;
932+ e. apply ( & mut |expr| {
933+ Ok ( if let Expr :: Column ( c) = & expr {
934+ match check_map. get ( & c. flat_name ( ) ) {
935+ Some ( _) => {
936+ is_contain = true ;
937+ VisitRecursion :: Stop
938+ }
939+ None => VisitRecursion :: Continue ,
940+ }
941+ } else {
942+ VisitRecursion :: Continue
943+ } )
944+ } )
945+ . unwrap ( ) ;
946+ is_contain
947+ }
948+
884949#[ cfg( test) ]
885950mod tests {
886951 use super :: * ;
@@ -893,9 +958,9 @@ mod tests {
893958 use datafusion_common:: { DFSchema , DFSchemaRef } ;
894959 use datafusion_expr:: logical_plan:: table_scan;
895960 use datafusion_expr:: {
896- and, col, in_list, in_subquery, lit, logical_plan:: JoinType , or, sum , BinaryExpr ,
897- Expr , Extension , LogicalPlanBuilder , Operator , TableSource , TableType ,
898- UserDefinedLogicalNodeCore ,
961+ and, col, in_list, in_subquery, lit, logical_plan:: JoinType , or, random , sum ,
962+ BinaryExpr , Expr , Extension , LogicalPlanBuilder , Operator , TableSource ,
963+ TableType , UserDefinedLogicalNodeCore ,
899964 } ;
900965 use std:: fmt:: { Debug , Formatter } ;
901966 use std:: sync:: Arc ;
@@ -2712,4 +2777,79 @@ Projection: a, b
27122777 \n TableScan: test2";
27132778 assert_optimized_plan_eq ( & plan, expected)
27142779 }
2780+
2781+ #[ test]
2782+ fn test_push_down_volatile_function_in_aggregate ( ) -> Result < ( ) > {
2783+ // SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
2784+ let table_scan = test_table_scan_with_name ( "test1" ) ?;
2785+ let plan = LogicalPlanBuilder :: from ( table_scan)
2786+ . aggregate ( vec ! [ col( "a" ) ] , vec ! [ sum( col( "b" ) ) ] ) ?
2787+ . project ( vec ! [
2788+ col( "a" ) ,
2789+ sum( col( "b" ) ) ,
2790+ add( random( ) , lit( 1 ) ) . alias( "r" ) ,
2791+ ] ) ?
2792+ . alias ( "t" ) ?
2793+ . filter ( col ( "t.a" ) . gt ( lit ( 5 ) ) . and ( col ( "t.r" ) . gt ( lit ( 0.5 ) ) ) ) ?
2794+ . project ( vec ! [ col( "t.a" ) , col( "t.r" ) ] ) ?
2795+ . build ( ) ?;
2796+
2797+ let expected_before = "Projection: t.a, t.r\
2798+ \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\
2799+ \n SubqueryAlias: t\
2800+ \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
2801+ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
2802+ \n TableScan: test1";
2803+ assert_eq ! ( format!( "{plan:?}" ) , expected_before) ;
2804+
2805+ let expected_after = "Projection: t.a, t.r\
2806+ \n SubqueryAlias: t\
2807+ \n Filter: r > Float64(0.5)\
2808+ \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
2809+ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
2810+ \n TableScan: test1, full_filters=[test1.a > Int32(5)]";
2811+ assert_optimized_plan_eq ( & plan, expected_after)
2812+ }
2813+
2814+ #[ test]
2815+ fn test_push_down_volatile_function_in_join ( ) -> Result < ( ) > {
2816+ // SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
2817+ let table_scan = test_table_scan_with_name ( "test1" ) ?;
2818+ let left = LogicalPlanBuilder :: from ( table_scan) . build ( ) ?;
2819+ let right_table_scan = test_table_scan_with_name ( "test2" ) ?;
2820+ let right = LogicalPlanBuilder :: from ( right_table_scan) . build ( ) ?;
2821+ let plan = LogicalPlanBuilder :: from ( left)
2822+ . join (
2823+ right,
2824+ JoinType :: Inner ,
2825+ (
2826+ vec ! [ Column :: from_qualified_name( "test1.a" ) ] ,
2827+ vec ! [ Column :: from_qualified_name( "test2.a" ) ] ,
2828+ ) ,
2829+ None ,
2830+ ) ?
2831+ . project ( vec ! [ col( "test1.a" ) . alias( "a" ) , random( ) . alias( "r" ) ] ) ?
2832+ . alias ( "t" ) ?
2833+ . filter ( col ( "t.r" ) . gt ( lit ( 0.8 ) ) ) ?
2834+ . project ( vec ! [ col( "t.a" ) , col( "t.r" ) ] ) ?
2835+ . build ( ) ?;
2836+
2837+ let expected_before = "Projection: t.a, t.r\
2838+ \n Filter: t.r > Float64(0.8)\
2839+ \n SubqueryAlias: t\
2840+ \n Projection: test1.a AS a, random() AS r\
2841+ \n Inner Join: test1.a = test2.a\
2842+ \n TableScan: test1\
2843+ \n TableScan: test2";
2844+ assert_eq ! ( format!( "{plan:?}" ) , expected_before) ;
2845+
2846+ let expected = "Projection: t.a, t.r\
2847+ \n SubqueryAlias: t\
2848+ \n Filter: r > Float64(0.8)\
2849+ \n Projection: test1.a AS a, random() AS r\
2850+ \n Inner Join: test1.a = test2.a\
2851+ \n TableScan: test1\
2852+ \n TableScan: test2";
2853+ assert_optimized_plan_eq ( & plan, expected)
2854+ }
27152855}
0 commit comments