Skip to content

Commit 1ed3963

Browse files
authored
Support unparsing plans with both Aggregation and Window functions (#35)
1 parent 34f7d90 commit 1ed3963

File tree

3 files changed

+105
-44
lines changed

3 files changed

+105
-44
lines changed

datafusion/sql/src/unparser/plan.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ use super::{
3636
rewrite_plan_for_sort_on_non_projected_fields,
3737
subquery_alias_inner_query_and_columns,
3838
},
39-
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
39+
utils::{
40+
find_agg_node_within_select, find_window_nodes_within_select,
41+
unproject_window_exprs,
42+
},
4043
Unparser,
4144
};
4245

@@ -170,13 +173,17 @@ impl Unparser<'_> {
170173
p: &Projection,
171174
select: &mut SelectBuilder,
172175
) -> Result<()> {
173-
match find_agg_node_within_select(plan, None, true) {
174-
Some(AggVariant::Aggregate(agg)) => {
176+
match (
177+
find_agg_node_within_select(plan, true),
178+
find_window_nodes_within_select(plan, None, true),
179+
) {
180+
(Some(agg), window) => {
181+
let window_option = window.as_deref();
175182
let items = p
176183
.expr
177184
.iter()
178185
.map(|proj_expr| {
179-
let unproj = unproject_agg_exprs(proj_expr, agg)?;
186+
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
180187
self.select_item_to_sql(&unproj)
181188
})
182189
.collect::<Result<Vec<_>>>()?;
@@ -190,7 +197,7 @@ impl Unparser<'_> {
190197
vec![],
191198
));
192199
}
193-
Some(AggVariant::Window(window)) => {
200+
(None, Some(window)) => {
194201
let items = p
195202
.expr
196203
.iter()
@@ -202,7 +209,7 @@ impl Unparser<'_> {
202209

203210
select.projection(items);
204211
}
205-
None => {
212+
_ => {
206213
let items = p
207214
.expr
208215
.iter()
@@ -285,10 +292,10 @@ impl Unparser<'_> {
285292
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
286293
}
287294
LogicalPlan::Filter(filter) => {
288-
if let Some(AggVariant::Aggregate(agg)) =
289-
find_agg_node_within_select(plan, None, select.already_projected())
295+
if let Some(agg) =
296+
find_agg_node_within_select(plan, select.already_projected())
290297
{
291-
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
298+
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
292299
let filter_expr = self.expr_to_sql(&unprojected)?;
293300
select.having(Some(filter_expr));
294301
} else {

datafusion/sql/src/unparser/utils.rs

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,81 @@
1818
use datafusion_common::{
1919
internal_err,
2020
tree_node::{Transformed, TreeNode},
21-
Result,
21+
Column, Result,
2222
};
2323
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Window};
2424

25-
/// One of the possible aggregation plans which can be found within a single select query.
26-
pub(crate) enum AggVariant<'a> {
27-
Aggregate(&'a Aggregate),
28-
Window(Vec<&'a Window>),
25+
/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
26+
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
27+
/// If an Aggregate or node is not found prior to this or at all before reaching the end
28+
/// of the tree, None is returned.
29+
pub(crate) fn find_agg_node_within_select(
30+
plan: &LogicalPlan,
31+
already_projected: bool,
32+
) -> Option<&Aggregate> {
33+
// Note that none of the nodes that have a corresponding node can have more
34+
// than 1 input node. E.g. Projection / Filter always have 1 input node.
35+
let input = plan.inputs();
36+
let input = if input.len() > 1 {
37+
return None;
38+
} else {
39+
input.first()?
40+
};
41+
// Agg nodes explicitly return immediately with a single node
42+
if let LogicalPlan::Aggregate(agg) = input {
43+
Some(agg)
44+
} else if let LogicalPlan::TableScan(_) = input {
45+
None
46+
} else if let LogicalPlan::Projection(_) = input {
47+
if already_projected {
48+
None
49+
} else {
50+
find_agg_node_within_select(input, true)
51+
}
52+
} else {
53+
find_agg_node_within_select(input, already_projected)
54+
}
2955
}
3056

31-
/// Recursively searches children of [LogicalPlan] to find an Aggregate or window node if one exists
57+
/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
3258
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
33-
/// If an Aggregate or window node is not found prior to this or at all before reaching the end
34-
/// of the tree, None is returned. It is assumed that a Window and Aggregate node cannot both
35-
/// be found in a single select query.
36-
pub(crate) fn find_agg_node_within_select<'a>(
59+
/// If Window node is not found prior to this or at all before reaching the end
60+
/// of the tree, None is returned.
61+
pub(crate) fn find_window_nodes_within_select<'a>(
3762
plan: &'a LogicalPlan,
38-
mut prev_windows: Option<AggVariant<'a>>,
63+
mut prev_windows: Option<Vec<&'a Window>>,
3964
already_projected: bool,
40-
) -> Option<AggVariant<'a>> {
41-
// Note that none of the nodes that have a corresponding agg node can have more
65+
) -> Option<Vec<&'a Window>> {
66+
// Note that none of the nodes that have a corresponding node can have more
4267
// than 1 input node. E.g. Projection / Filter always have 1 input node.
4368
let input = plan.inputs();
4469
let input = if input.len() > 1 {
45-
return None;
70+
return prev_windows;
4671
} else {
4772
input.first()?
4873
};
4974

50-
// Agg nodes explicitly return immediately with a single node
5175
// Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
5276
match input {
53-
LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)),
5477
LogicalPlan::Window(window) => {
5578
prev_windows = match &mut prev_windows {
56-
Some(AggVariant::Window(windows)) => {
79+
Some(windows) => {
5780
windows.push(window);
5881
prev_windows
5982
}
60-
_ => Some(AggVariant::Window(vec![window])),
83+
_ => Some(vec![window]),
6184
};
62-
find_agg_node_within_select(input, prev_windows, already_projected)
85+
find_window_nodes_within_select(input, prev_windows, already_projected)
6386
}
6487
LogicalPlan::Projection(_) => {
6588
if already_projected {
6689
prev_windows
6790
} else {
68-
find_agg_node_within_select(input, prev_windows, true)
91+
find_window_nodes_within_select(input, prev_windows, true)
6992
}
7093
}
7194
LogicalPlan::TableScan(_) => prev_windows,
72-
_ => find_agg_node_within_select(input, prev_windows, already_projected),
95+
_ => find_window_nodes_within_select(input, prev_windows, already_projected),
7396
}
7497
}
7598

@@ -78,19 +101,30 @@ pub(crate) fn find_agg_node_within_select<'a>(
78101
///
79102
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
80103
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
81-
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
104+
pub(crate) fn unproject_agg_exprs(
105+
expr: &Expr,
106+
agg: &Aggregate,
107+
windows: Option<&[&Window]>,
108+
) -> Result<Expr> {
82109
expr.clone()
83110
.transform(|sub_expr| {
84111
if let Expr::Column(c) = sub_expr {
85-
// find the column in the agg schema
86-
if let Ok(n) = agg.schema.index_of_column(&c) {
87-
let unprojected_expr = agg
88-
.group_expr
89-
.iter()
90-
.chain(agg.aggr_expr.iter())
91-
.nth(n)
92-
.unwrap();
112+
if let Some(unprojected_expr) = find_agg_expr(agg, &c) {
93113
Ok(Transformed::yes(unprojected_expr.clone()))
114+
} else if let Some(mut unprojected_expr) =
115+
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
116+
{
117+
if let Expr::WindowFunction(func) = &mut unprojected_expr {
118+
// Window function can contain aggregation column, for ex 'avg(sum(ss_sales_price)) over ..' that needs to be unprojected
119+
for arg in &mut func.args {
120+
if let Expr::Column(c) = arg {
121+
if let Some(expr) = find_agg_expr(agg, c) {
122+
*arg = expr.clone();
123+
}
124+
}
125+
}
126+
}
127+
Ok(Transformed::yes(unprojected_expr))
94128
} else {
95129
internal_err!(
96130
"Tried to unproject agg expr not found in provided Aggregate!"
@@ -112,11 +146,7 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
112146
expr.clone()
113147
.transform(|sub_expr| {
114148
if let Expr::Column(c) = sub_expr {
115-
if let Some(unproj) = windows
116-
.iter()
117-
.flat_map(|w| w.window_expr.iter())
118-
.find(|window_expr| window_expr.schema_name().to_string() == c.name)
119-
{
149+
if let Some(unproj) = find_window_expr(windows, &c.name) {
120150
Ok(Transformed::yes(unproj.clone()))
121151
} else {
122152
Ok(Transformed::no(Expr::Column(c)))
@@ -127,3 +157,21 @@ pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result
127157
})
128158
.map(|e| e.data)
129159
}
160+
161+
fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Option<&'a Expr> {
162+
if let Ok(index) = agg.schema.index_of_column(column) {
163+
agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)
164+
} else {
165+
None
166+
}
167+
}
168+
169+
fn find_window_expr<'a>(
170+
windows: &'a [&'a Window],
171+
column_name: &'a str,
172+
) -> Option<&'a Expr> {
173+
windows
174+
.iter()
175+
.flat_map(|w| w.window_expr.iter())
176+
.find(|expr| expr.schema_name().to_string() == column_name)
177+
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ fn roundtrip_statement() -> Result<()> {
146146
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
147147
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
148148
"WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)",
149+
r#"SELECT id, first_name,
150+
SUM(id) AS total_sum,
151+
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
152+
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
153+
FROM person GROUP BY id, first_name"#,
149154
];
150155

151156
// For each test sql string, we transform as follows:
@@ -161,6 +166,7 @@ fn roundtrip_statement() -> Result<()> {
161166
let state = MockSessionState::default()
162167
.with_aggregate_function(sum_udaf())
163168
.with_aggregate_function(count_udaf())
169+
.with_aggregate_function(max_udaf())
164170
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
165171
let context = MockContextProvider { state };
166172
let sql_to_rel = SqlToRel::new(&context);

0 commit comments

Comments
 (0)