Skip to content

Commit 443fdc1

Browse files
committed
support pivot
1 parent adb7258 commit 443fdc1

File tree

3 files changed

+276
-4
lines changed

3 files changed

+276
-4
lines changed

datafusion/sql/src/query.rs

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,19 @@ use std::sync::Arc;
2020
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
2121

2222
use crate::stack::StackGuard;
23+
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
2324
use datafusion_common::{not_impl_err, Constraints, DFSchema, Result};
24-
use datafusion_expr::expr::{Sort, WildcardOptions};
25+
use datafusion_expr::expr::{AggregateFunction, Sort, WildcardOptions};
2526

2627
use datafusion_expr::select_expr::SelectExpr;
2728
use datafusion_expr::{
28-
CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder,
29+
case, col, lit, CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan,
30+
LogicalPlanBuilder,
2931
};
3032
use sqlparser::ast::{
3133
Expr as SQLExpr, ExprWithAliasAndOrderBy, Ident, LimitClause, Offset, OffsetRows,
32-
OrderBy, OrderByExpr, OrderByKind, PipeOperator, Query, SelectInto, SetExpr,
33-
SetOperator, SetQuantifier, TableAlias,
34+
OrderBy, OrderByExpr, OrderByKind, PipeOperator, PivotValueSource, Query, SelectInto,
35+
SetExpr, SetOperator, SetQuantifier, TableAlias,
3436
};
3537
use sqlparser::tokenizer::Span;
3638

@@ -200,6 +202,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
200202
PipeOperator::Join(join) => {
201203
self.parse_relation_join(plan, join, planner_context)
202204
}
205+
PipeOperator::Pivot {
206+
aggregate_functions,
207+
value_column,
208+
value_source,
209+
alias,
210+
} => self.pipe_operator_pivot(
211+
plan,
212+
aggregate_functions,
213+
value_column,
214+
value_source,
215+
alias,
216+
planner_context,
217+
),
203218

204219
x => not_impl_err!("`{x}` pipe operator is not supported yet"),
205220
}
@@ -376,6 +391,143 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
376391
.build()
377392
}
378393

394+
/// Handle PIVOT pipe operator
395+
fn pipe_operator_pivot(
396+
&self,
397+
plan: LogicalPlan,
398+
aggregate_functions: Vec<sqlparser::ast::ExprWithAlias>,
399+
value_column: Vec<Ident>,
400+
value_source: PivotValueSource,
401+
alias: Option<Ident>,
402+
planner_context: &mut PlannerContext,
403+
) -> Result<LogicalPlan> {
404+
// Extract pivot values from the value source
405+
let pivot_values = if let PivotValueSource::List(values) = value_source {
406+
values
407+
} else {
408+
return not_impl_err!(
409+
"Only static pivot value lists are supported currently"
410+
);
411+
};
412+
413+
// Convert pivot column to DataFusion expression
414+
if value_column.len() != 1 {
415+
return not_impl_err!("Multi-column pivot is not supported yet");
416+
}
417+
let pivot_col_name = &value_column[0].value;
418+
let pivot_col_expr = col(pivot_col_name);
419+
420+
let input_schema = plan.schema();
421+
422+
// Convert sql to DF exprs
423+
let aggregate_functions = aggregate_functions
424+
.into_iter()
425+
.map(|f| self.sql_to_expr_with_alias(f, input_schema, planner_context))
426+
.collect::<Result<Vec<_>, _>>()?;
427+
428+
// Convert aggregate functions to logical expressions to extract measure columns
429+
let mut measure_columns = std::collections::HashSet::new();
430+
for agg_func_with_alias in &aggregate_functions {
431+
agg_func_with_alias.apply(|e| {
432+
if let Expr::Column(col) = e {
433+
measure_columns.insert(col.name.clone());
434+
};
435+
Ok(TreeNodeRecursion::Continue)
436+
})?;
437+
}
438+
439+
// Get all column names from the input plan to determine group-by columns.
440+
// Add all columns except the pivot column and measure columns to group by
441+
let mut group_by_cols = Vec::new();
442+
for field in input_schema.fields() {
443+
let col_name = field.name();
444+
if col_name != pivot_col_name && !measure_columns.contains(col_name) {
445+
group_by_cols.push(col(col_name));
446+
}
447+
}
448+
449+
// Create aggregate expressions for each pivot value
450+
let mut aggr_exprs = Vec::new();
451+
452+
// For each pivot value and aggregate function combination, create a conditional aggregate
453+
// Process pivot values first to get the desired column order
454+
for pivot_value in pivot_values {
455+
let pivot_value_expr = self.sql_to_expr(
456+
pivot_value.expr.clone(),
457+
input_schema,
458+
planner_context,
459+
)?;
460+
for agg_func_with_alias in &aggregate_functions {
461+
let (alias_name, mut agg_fn) = match agg_func_with_alias {
462+
Expr::Alias(alias) => match *alias.expr.clone() {
463+
Expr::Alias(inner_alias) => {
464+
let Expr::AggregateFunction(
465+
agg_func @ AggregateFunction { .. },
466+
) = *inner_alias.expr.clone()
467+
else {
468+
return not_impl_err!("Only function expressions are supported in PIVOT aggregate functions");
469+
};
470+
(Some(alias.name.clone()), agg_func)
471+
}
472+
Expr::AggregateFunction(agg_func @ AggregateFunction { .. }) => {
473+
(Some(alias.name.clone()), agg_func)
474+
}
475+
_ => {
476+
return not_impl_err!("Only function expressions are supported in PIVOT aggregate functions");
477+
}
478+
},
479+
Expr::AggregateFunction(agg_func) => (None, agg_func.clone()),
480+
_ => {
481+
return not_impl_err!("Expected aggregate function");
482+
}
483+
};
484+
485+
if agg_fn.params.args.len() != 1 {
486+
return not_impl_err!(
487+
"Only exactly one aggregate function argument is supported"
488+
);
489+
}
490+
let arg = &mut agg_fn.params.args[0];
491+
*arg = case(pivot_col_expr.clone())
492+
.when(pivot_value_expr.clone(), arg.clone())
493+
.otherwise(lit(datafusion_common::ScalarValue::Null))?;
494+
495+
let agg_expr = Expr::AggregateFunction(agg_fn);
496+
let aggr_func_alias = alias_name.unwrap_or(agg_expr.name_for_alias()?);
497+
498+
let pivot_value_name = if let Some(alias) = &pivot_value.alias {
499+
alias.value.clone()
500+
} else {
501+
// Use the pivot value as column name, stripping quotes
502+
pivot_value.expr.to_string().trim_matches('\'').to_string()
503+
};
504+
505+
aggr_exprs.push(
506+
// Give unique name based on pivot column name
507+
agg_expr.alias(format!("{aggr_func_alias}_{pivot_value_name}")),
508+
);
509+
}
510+
}
511+
512+
// Create the aggregate logical plan
513+
let result_plan = LogicalPlanBuilder::from(plan)
514+
.aggregate(group_by_cols, aggr_exprs)?
515+
.build()?;
516+
517+
// Apply table alias if provided
518+
if let Some(table_alias) = alias {
519+
self.apply_table_alias(
520+
result_plan,
521+
TableAlias {
522+
name: table_alias,
523+
columns: vec![],
524+
},
525+
)
526+
} else {
527+
Ok(result_plan)
528+
}
529+
}
530+
379531
/// Wrap the logical plan in a `SelectInto`
380532
fn select_into(
381533
&self,

datafusion/sqllogictest/test_files/pipe_operator.slt

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,95 @@ query TII
195195
----
196196
apples 2 123
197197
bananas 5 NULL
198+
199+
# PIVOT pipe
200+
201+
statement ok
202+
CREATE TABLE pipe_test(
203+
product VARCHAR,
204+
sales INT,
205+
quarter VARCHAR,
206+
year INT
207+
) AS VALUES
208+
('Kale', 51, 'Q1', 2020),
209+
('Kale', 23, 'Q2', 2020),
210+
('Kale', 45, 'Q3', 2020),
211+
('Kale', 3, 'Q4', 2020),
212+
('Kale', 70, 'Q1', 2021),
213+
('Kale', 85, 'Q2', 2021),
214+
('Apple', 77, 'Q1', 2020),
215+
('Apple', 0, 'Q2', 2020),
216+
('Apple', 1, 'Q1', 2021)
217+
;
218+
219+
query TIIIII rowsort
220+
SELECT * FROM pipe_test
221+
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'));
222+
----
223+
Apple 2020 77 0 NULL NULL
224+
Apple 2021 1 NULL NULL NULL
225+
Kale 2020 51 23 45 3
226+
Kale 2021 70 85 NULL NULL
227+
228+
query TIIII rowsort
229+
SELECT * FROM pipe_test
230+
|> select product, sales, quarter
231+
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'));
232+
----
233+
Apple 78 0 NULL NULL
234+
Kale 121 108 45 3
235+
236+
query TIII rowsort
237+
SELECT * FROM pipe_test
238+
|> select product, sales, quarter
239+
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3'));
240+
----
241+
Apple 78 0 NULL
242+
Kale 121 108 45
243+
244+
query TIIII rowsort
245+
SELECT * FROM pipe_test
246+
|> select product, sales, quarter
247+
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2'));
248+
----
249+
Apple 78 2 0 1
250+
Kale 121 2 108 2
251+
252+
253+
query TT
254+
EXPLAIN SELECT * FROM pipe_test
255+
|> select product, sales, quarter
256+
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2'));
257+
----
258+
logical_plan
259+
01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(CAST(CASE pipe_test.quarter WHEN Utf8View("Q1") THEN pipe_test.sales ELSE Int32(NULL) END AS Int64)) AS total_sales_Q1, count(CASE pipe_test.quarter WHEN Utf8View("Q1") THEN Int64(1) ELSE Int64(NULL) END) AS num_records_Q1, sum(CAST(CASE pipe_test.quarter WHEN Utf8View("Q2") THEN pipe_test.sales ELSE Int32(NULL) END AS Int64)) AS total_sales_Q2, count(CASE pipe_test.quarter WHEN Utf8View("Q2") THEN Int64(1) ELSE Int64(NULL) END) AS num_records_Q2]]
260+
02)--TableScan: pipe_test projection=[product, sales, quarter]
261+
physical_plan
262+
01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2]
263+
02)--CoalesceBatchesExec: target_batch_size=8192
264+
03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4
265+
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
266+
05)--------AggregateExec: mode=Partial, gby=[product@0 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2]
267+
06)----------DataSourceExec: partitions=1, partition_sizes=[1]
268+
269+
# With explicit pivot value alias
270+
query TT
271+
EXPLAIN SELECT * FROM pipe_test
272+
|> select product, sales, quarter
273+
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1' as q1, 'Q2'));
274+
----
275+
logical_plan
276+
01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(CAST(CASE pipe_test.quarter WHEN Utf8View("Q1") THEN pipe_test.sales ELSE Int32(NULL) END AS Int64)) AS total_sales_q1, count(CASE pipe_test.quarter WHEN Utf8View("Q1") THEN Int64(1) ELSE Int64(NULL) END) AS num_records_q1, sum(CAST(CASE pipe_test.quarter WHEN Utf8View("Q2") THEN pipe_test.sales ELSE Int32(NULL) END AS Int64)) AS total_sales_Q2, count(CASE pipe_test.quarter WHEN Utf8View("Q2") THEN Int64(1) ELSE Int64(NULL) END) AS num_records_Q2]]
277+
02)--TableScan: pipe_test projection=[product, sales, quarter]
278+
physical_plan
279+
01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2]
280+
02)--CoalesceBatchesExec: target_batch_size=8192
281+
03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4
282+
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
283+
05)--------AggregateExec: mode=Partial, gby=[product@0 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2]
284+
06)----------DataSourceExec: partitions=1, partition_sizes=[1]
285+
286+
# Aggregation functions with multiple parameters
287+
query error DataFusion error: This feature is not implemented: Only exactly one aggregate function argument is supported
288+
SELECT product, sales, quarter FROM pipe_test
289+
|> PIVOT(string_agg(sales, '_' order by sales) as agg FOR quarter IN ('Q1', 'Q2'));

docs/source/user-guide/sql/operators.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ DataFusion currently supports the following pipe operators:
636636
- [INTERSECT](#pipe_intersect)
637637
- [EXCEPT](#pipe_except)
638638
- [AGGREGATE](#pipe_aggregate)
639+
- [PIVOT](#pipe_pivot)
639640
640641
(pipe_where)=
641642
@@ -825,3 +826,30 @@ DataFusion currently supports the following pipe operators:
825826
| 3 |
826827
+-------+
827828
```
829+
830+
(pipe_pivot)=
831+
832+
### PIVOT
833+
834+
Rotates rows into columns.
835+
836+
```sql
837+
> (
838+
SELECT 'kale' AS product, 51 AS sales, 'Q1' AS quarter
839+
UNION ALL
840+
SELECT 'kale' AS product, 4 AS sales, 'Q1' AS quarter
841+
UNION ALL
842+
SELECT 'kale' AS product, 45 AS sales, 'Q2' AS quarter
843+
UNION ALL
844+
SELECT 'apple' AS product, 8 AS sales, 'Q1' AS quarter
845+
UNION ALL
846+
SELECT 'apple' AS product, 10 AS sales, 'Q2' AS quarter
847+
)
848+
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'));
849+
+---------+-----+-----+
850+
| product | Q1 | Q2 |
851+
+---------+-----+-----+
852+
| apple | 8 | 10 |
853+
| kale | 55 | 45 |
854+
+---------+-----+-----+
855+
```

0 commit comments

Comments
 (0)