1818use datafusion_common:: {
1919 internal_err,
2020 tree_node:: { Transformed , TreeNode } ,
21- Result ,
21+ Column , Result ,
2222} ;
2323use datafusion_expr:: { Aggregate , Expr , LogicalPlan , Window } ;
2424
25- /// One of the possible aggregation plans which can be found within a single select query.
26- pub ( crate ) enum AggVariant < ' a > {
27- Aggregate ( & ' a Aggregate ) ,
28- Window ( Vec < & ' a Window > ) ,
25+ /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
26+ /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
27+ /// If an Aggregate or node is not found prior to this or at all before reaching the end
28+ /// of the tree, None is returned.
29+ pub ( crate ) fn find_agg_node_within_select (
30+ plan : & LogicalPlan ,
31+ already_projected : bool ,
32+ ) -> Option < & Aggregate > {
33+ // Note that none of the nodes that have a corresponding node can have more
34+ // than 1 input node. E.g. Projection / Filter always have 1 input node.
35+ let input = plan. inputs ( ) ;
36+ let input = if input. len ( ) > 1 {
37+ return None ;
38+ } else {
39+ input. first ( ) ?
40+ } ;
41+ // Agg nodes explicitly return immediately with a single node
42+ if let LogicalPlan :: Aggregate ( agg) = input {
43+ Some ( agg)
44+ } else if let LogicalPlan :: TableScan ( _) = input {
45+ None
46+ } else if let LogicalPlan :: Projection ( _) = input {
47+ if already_projected {
48+ None
49+ } else {
50+ find_agg_node_within_select ( input, true )
51+ }
52+ } else {
53+ find_agg_node_within_select ( input, already_projected)
54+ }
2955}
3056
31- /// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
57+ /// Recursively searches children of [LogicalPlan] to find Window nodes if exist
3258/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
33- /// If an Aggregate or window node is not found prior to this or at all before reaching the end
34- /// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
35- /// be found in a single select query.
36- pub ( crate ) fn find_agg_node_within_select < ' a > (
59+ /// If Window node is not found prior to this or at all before reaching the end
60+ /// of the tree, None is returned.
61+ pub ( crate ) fn find_window_nodes_within_select < ' a > (
3762 plan : & ' a LogicalPlan ,
38- mut prev_windows : Option < AggVariant < ' a > > ,
63+ mut prev_windows : Option < Vec < & ' a Window > > ,
3964 already_projected : bool ,
40- ) -> Option < AggVariant < ' a > > {
41- // Note that none of the nodes that have a corresponding agg node can have more
65+ ) -> Option < Vec < & ' a Window > > {
66+ // Note that none of the nodes that have a corresponding node can have more
4267 // than 1 input node. E.g. Projection / Filter always have 1 input node.
4368 let input = plan. inputs ( ) ;
4469 let input = if input. len ( ) > 1 {
45- return None ;
70+ return prev_windows ;
4671 } else {
4772 input. first ( ) ?
4873 } ;
4974
50- // Agg nodes explicitly return immediately with a single node
5175 // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
5276 match input {
53- LogicalPlan :: Aggregate ( agg) => Some ( AggVariant :: Aggregate ( agg) ) ,
5477 LogicalPlan :: Window ( window) => {
5578 prev_windows = match & mut prev_windows {
56- Some ( AggVariant :: Window ( windows) ) => {
79+ Some ( windows) => {
5780 windows. push ( window) ;
5881 prev_windows
5982 }
60- _ => Some ( AggVariant :: Window ( vec ! [ window] ) ) ,
83+ _ => Some ( vec ! [ window] ) ,
6184 } ;
62- find_agg_node_within_select ( input, prev_windows, already_projected)
85+ find_window_nodes_within_select ( input, prev_windows, already_projected)
6386 }
6487 LogicalPlan :: Projection ( _) => {
6588 if already_projected {
6689 prev_windows
6790 } else {
68- find_agg_node_within_select ( input, prev_windows, true )
91+ find_window_nodes_within_select ( input, prev_windows, true )
6992 }
7093 }
7194 LogicalPlan :: TableScan ( _) => prev_windows,
72- _ => find_agg_node_within_select ( input, prev_windows, already_projected) ,
95+ _ => find_window_nodes_within_select ( input, prev_windows, already_projected) ,
7396 }
7497}
7598
@@ -78,19 +101,30 @@ pub(crate) fn find_agg_node_within_select<'a>(
78101///
79102/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
80103/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
81- pub ( crate ) fn unproject_agg_exprs ( expr : & Expr , agg : & Aggregate ) -> Result < Expr > {
104+ pub ( crate ) fn unproject_agg_exprs (
105+ expr : & Expr ,
106+ agg : & Aggregate ,
107+ windows : Option < & [ & Window ] > ,
108+ ) -> Result < Expr > {
82109 expr. clone ( )
83110 . transform ( |sub_expr| {
84111 if let Expr :: Column ( c) = sub_expr {
85- // find the column in the agg schema
86- if let Ok ( n) = agg. schema . index_of_column ( & c) {
87- let unprojected_expr = agg
88- . group_expr
89- . iter ( )
90- . chain ( agg. aggr_expr . iter ( ) )
91- . nth ( n)
92- . unwrap ( ) ;
112+ if let Some ( unprojected_expr) = find_agg_expr ( agg, & c) {
93113 Ok ( Transformed :: yes ( unprojected_expr. clone ( ) ) )
114+ } else if let Some ( mut unprojected_expr) =
115+ windows. and_then ( |w| find_window_expr ( w, & c. name ) . cloned ( ) )
116+ {
117+ if let Expr :: WindowFunction ( func) = & mut unprojected_expr {
118+ // Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected
119+ for arg in & mut func. args {
120+ if let Expr :: Column ( c) = arg {
121+ if let Some ( expr) = find_agg_expr ( agg, c) {
122+ * arg = expr. clone ( ) ;
123+ }
124+ }
125+ }
126+ }
127+ Ok ( Transformed :: yes ( unprojected_expr) )
94128 } else {
95129 internal_err ! (
96130 "Tried to unproject agg expr not found in provided Aggregate!"
@@ -112,11 +146,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
112146 expr. clone ( )
113147 . transform ( |sub_expr| {
114148 if let Expr :: Column ( c) = sub_expr {
115- if let Some ( unproj) = windows
116- . iter ( )
117- . flat_map ( |w| w. window_expr . iter ( ) )
118- . find ( |window_expr| window_expr. schema_name ( ) . to_string ( ) == c. name )
119- {
149+ if let Some ( unproj) = find_window_expr ( windows, & c. name ) {
120150 Ok ( Transformed :: yes ( unproj. clone ( ) ) )
121151 } else {
122152 Ok ( Transformed :: no ( Expr :: Column ( c) ) )
@@ -127,3 +157,21 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
127157 } )
128158 . map ( |e| e. data )
129159}
160+
161+ fn find_agg_expr < ' a > ( agg : & ' a Aggregate , column : & Column ) -> Option < & ' a Expr > {
162+ if let Ok ( index) = agg. schema . index_of_column ( column) {
163+ agg. group_expr . iter ( ) . chain ( agg. aggr_expr . iter ( ) ) . nth ( index)
164+ } else {
165+ None
166+ }
167+ }
168+
169+ fn find_window_expr < ' a > (
170+ windows : & ' a [ & ' a Window ] ,
171+ column_name : & ' a str ,
172+ ) -> Option < & ' a Expr > {
173+ windows
174+ . iter ( )
175+ . flat_map ( |w| w. window_expr . iter ( ) )
176+ . find ( |expr| expr. schema_name ( ) . to_string ( ) == column_name)
177+ }
0 commit comments