1919
2020use std:: sync:: Arc ;
2121
22- use datafusion_common:: tree_node:: Transformed ;
22+ use datafusion_common:: tree_node:: { Transformed , TreeNode } ;
2323use datafusion_common:: { DFSchema , DFSchemaRef , DataFusionError , Result } ;
2424use datafusion_expr:: execution_props:: ExecutionProps ;
2525use datafusion_expr:: logical_plan:: LogicalPlan ;
2626use datafusion_expr:: simplify:: SimplifyContext ;
2727use datafusion_expr:: utils:: merge_schema;
28+ use datafusion_expr:: Expr ;
2829
2930use crate :: optimizer:: ApplyOrder ;
3031use crate :: utils:: NamePreserver ;
@@ -122,14 +123,21 @@ impl SimplifyExpressions {
122123
123124 // Preserve expression names to avoid changing the schema of the plan.
124125 let name_preserver = NamePreserver :: new ( & plan) ;
125- plan. map_expressions ( |e| {
126- let original_name = name_preserver. save ( & e) ;
127- let new_e = simplifier
128- . simplify ( e)
129- . map ( |expr| original_name. restore ( expr) ) ?;
126+ let mut rewrite_expr = |expr : Expr | {
127+ let name = name_preserver. save ( & expr) ;
128+ let expr = simplifier. simplify ( expr) ?;
130129 // TODO it would be nice to have a way to know if the expression was simplified
131130 // or not. For now conservatively return Transformed::yes
132- Ok ( Transformed :: yes ( new_e) )
131+ Ok ( Transformed :: yes ( name. restore ( expr) ) )
132+ } ;
133+
134+ plan. map_expressions ( |expr| {
135+ // Preserve the aliasing of grouping sets.
136+ if let Expr :: GroupingSet ( _) = & expr {
137+ expr. map_children ( & mut rewrite_expr)
138+ } else {
139+ rewrite_expr ( expr)
140+ }
133141 } )
134142 }
135143}
@@ -151,11 +159,7 @@ mod tests {
151159 use crate :: optimizer:: Optimizer ;
152160 use datafusion_expr:: logical_plan:: builder:: table_scan_with_filters;
153161 use datafusion_expr:: logical_plan:: table_scan;
154- use datafusion_expr:: {
155- and, binary_expr, col, lit, logical_plan:: builder:: LogicalPlanBuilder , Expr ,
156- ExprSchemable , JoinType ,
157- } ;
158- use datafusion_expr:: { or, BinaryExpr , Cast , Operator } ;
162+ use datafusion_expr:: * ;
159163 use datafusion_functions_aggregate:: expr_fn:: { max, min} ;
160164
161165 use crate :: test:: { assert_fields_eq, test_table_scan_with_name} ;
@@ -743,4 +747,24 @@ mod tests {
743747
744748 assert_optimized_plan_eq ( plan, expected)
745749 }
750+
751+ #[ test]
752+ fn simplify_grouping_sets ( ) -> Result < ( ) > {
753+ let table_scan = test_table_scan ( ) ;
754+ let plan = LogicalPlanBuilder :: from ( table_scan)
755+ . aggregate (
756+ [ grouping_set ( vec ! [
757+ vec![ ( lit( 42 ) . alias( "prev" ) + lit( 1 ) ) . alias( "age" ) , col( "a" ) ] ,
758+ vec![ col( "a" ) . or( col( "b" ) ) . and( lit( 1 ) . lt( lit( 0 ) ) ) . alias( "cond" ) ] ,
759+ vec![ col( "d" ) . alias( "e" ) , ( lit( 1 ) + lit( 2 ) ) ] ,
760+ ] ) ] ,
761+ [ ] as [ Expr ; 0 ] ,
762+ ) ?
763+ . build ( ) ?;
764+
765+ let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\
766+ \n TableScan: test";
767+
768+ assert_optimized_plan_eq ( plan, expected)
769+ }
746770}
0 commit comments