Skip to content

Commit

Permalink
Add basic support for unnest unparsing (apache#13129)
Browse files Browse the repository at this point in the history
* Add basic support for `unnest` unparsing (#45)

* Fix taplo cargo check
  • Loading branch information
sgrebnov authored Oct 28, 2024
1 parent 132b232 commit 1fd6116
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 32 deletions.
1 change: 1 addition & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ strum = { version = "0.26.1", features = ["derive"] }
ctor = { workspace = true }
datafusion-functions = { workspace = true, default-features = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-functions-nested = { workspace = true }
datafusion-functions-window = { workspace = true }
env_logger = { workspace = true }
paste = "^1.0"
Expand Down
35 changes: 34 additions & 1 deletion datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_expr::expr::Unnest;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName,
Expand Down Expand Up @@ -466,7 +467,7 @@ impl Unparser<'_> {
Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string())))
}
Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col),
Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"),
Expr::Unnest(unnest) => self.unnest_to_sql(unnest),
}
}

Expand Down Expand Up @@ -1340,6 +1341,29 @@ impl Unparser<'_> {
}
}

/// Builds the SQL AST for an [`Unnest`] expression.
///
/// There is no direct representation for UNNEST in the AST, so the result is
/// an ordinary function call named `UNNEST` (unquoted) whose only argument is
/// the unparsed inner expression.
///
/// Returns an error if the inner expression cannot be converted to a SQL
/// function argument.
fn unnest_to_sql(&self, unnest: &Unnest) -> Result<ast::Expr> {
    // The inner expression is handed over as a one-element slice because
    // `function_args_to_sql` operates on a list of arguments.
    let arg_list = ast::FunctionArgumentList {
        duplicate_treatment: None,
        args: self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?,
        clauses: vec![],
    };

    let unnest_ident = Ident {
        value: "UNNEST".to_string(),
        quote_style: None,
    };

    // Plain call: no FILTER, null treatment, OVER, WITHIN GROUP or parameters.
    let call = Function {
        name: ast::ObjectName(vec![unnest_ident]),
        parameters: ast::FunctionArguments::None,
        args: ast::FunctionArguments::List(arg_list),
        filter: None,
        null_treatment: None,
        over: None,
        within_group: vec![],
    };

    Ok(ast::Expr::Function(call))
}

fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> {
match data_type {
DataType::Null => {
Expand Down Expand Up @@ -1855,6 +1879,15 @@ mod tests {
}),
r#"CAST(a AS DECIMAL(12,0))"#,
),
(
Expr::Unnest(Unnest {
expr: Box::new(Expr::Column(Column {
relation: Some(TableReference::partial("schema", "table")),
name: "array_col".to_string(),
})),
}),
r#"UNNEST("schema"."table".array_col)"#,
),
];

for (expr, expected) in tests {
Expand Down
53 changes: 42 additions & 11 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ use super::{
subquery_alias_inner_query_and_columns, TableAliasRewriter,
},
utils::{
find_agg_node_within_select, find_window_nodes_within_select,
unproject_sort_expr, unproject_window_exprs,
find_agg_node_within_select, find_unnest_node_within_select,
find_window_nodes_within_select, unproject_sort_expr, unproject_unnest_expr,
unproject_window_exprs,
},
Unparser,
};
Expand Down Expand Up @@ -173,15 +174,24 @@ impl Unparser<'_> {
p: &Projection,
select: &mut SelectBuilder,
) -> Result<()> {
let mut exprs = p.expr.clone();

// If an Unnest node is found within the select, find and unproject the unnest column
if let Some(unnest) = find_unnest_node_within_select(plan) {
exprs = exprs
.into_iter()
.map(|e| unproject_unnest_expr(e, unnest))
.collect::<Result<Vec<_>>>()?;
};

match (
find_agg_node_within_select(plan, true),
find_window_nodes_within_select(plan, None, true),
) {
(Some(agg), window) => {
let window_option = window.as_deref();
let items = p
.expr
.iter()
let items = exprs
.into_iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?;
self.select_item_to_sql(&unproj)
Expand All @@ -198,9 +208,8 @@ impl Unparser<'_> {
));
}
(None, Some(window)) => {
let items = p
.expr
.iter()
let items = exprs
.into_iter()
.map(|proj_expr| {
let unproj = unproject_window_exprs(proj_expr, &window)?;
self.select_item_to_sql(&unproj)
Expand All @@ -210,8 +219,7 @@ impl Unparser<'_> {
select.projection(items);
}
_ => {
let items = p
.expr
let items = exprs
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
Expand Down Expand Up @@ -318,7 +326,8 @@ impl Unparser<'_> {
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg, None)?;
let unprojected =
unproject_agg_exprs(filter.predicate.clone(), agg, None)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {
Expand Down Expand Up @@ -596,6 +605,28 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"),
LogicalPlan::Unnest(unnest) => {
if !unnest.struct_type_columns.is_empty() {
return internal_err!(
"Struct type columns are not currently supported in UNNEST: {:?}",
unnest.struct_type_columns
);
}

// In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip.
// Otherwise, there will be a duplicate SELECT clause.
// | Projection: table.col1, UNNEST(table.col2)
// | Unnest: UNNEST(table.col2)
// | Projection: table.col1, table.col2 AS UNNEST(table.col2)
// | Filter: table.col3 = Int64(3)
// | TableScan: table projection=None
if let LogicalPlan::Projection(p) = unnest.input.as_ref() {
// continue with projection input
self.select_to_sql_recursively(&p.input, query, select, relation)
} else {
internal_err!("Unnest input is not a Projection: {unnest:?}")
}
}
_ => not_impl_err!("Unsupported operator: {plan:?}"),
}
}
Expand Down
87 changes: 68 additions & 19 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use datafusion_common::{
Column, Result, ScalarValue,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection, SortExpr,
Window,
expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Projection,
SortExpr, Unnest, Window,
};
use sqlparser::ast;

Expand Down Expand Up @@ -62,6 +62,28 @@ pub(crate) fn find_agg_node_within_select(
}
}

/// Walks down the single-input chain below `plan` looking for an
/// [`LogicalPlan::Unnest`] node.
///
/// Returns `None` when:
/// - the current node has no input or more than one input, or
/// - a `TableScan` or `Projection` is reached before any `Unnest` node.
///
/// Any other node is stepped through and the search continues with its input.
pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> {
    let children = plan.inputs();

    // Only follow linear chains; a node with several inputs ends the search.
    if children.len() > 1 {
        return None;
    }
    let child: &LogicalPlan = children.first().copied()?;

    match child {
        LogicalPlan::Unnest(unnest) => Some(unnest),
        // The search does not descend past a table scan or a projection.
        LogicalPlan::TableScan(_) | LogicalPlan::Projection(_) => None,
        other => find_unnest_node_within_select(other),
    }
}

/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
/// If Window node is not found prior to this or at all before reaching the end
Expand Down Expand Up @@ -104,26 +126,54 @@ pub(crate) fn find_window_nodes_within_select<'a>(
}
}

/// Rewrites column references that stand in for unnested list columns back
/// into real `UNNEST(...)` expressions.
///
/// For example, if expr contains the column expr
/// "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)"
/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]).
///
/// A column is rewritten only when its name matches the output column of one
/// of `unnest.list_type_columns`. The replacement is found by looking up the
/// column's index in the Unnest schema and taking the expression at that
/// position in the Projection that feeds the Unnest node.
///
/// Returns an internal error when a matching column cannot be resolved this
/// way (index lookup fails, the Unnest input is not a Projection, or the
/// Projection has no expression at that index).
pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
    expr.transform(|sub_expr| {
        // Anything other than a column reference is left untouched.
        let col_ref = match &sub_expr {
            Expr::Column(col_ref) => col_ref,
            _ => return Ok(Transformed::no(sub_expr)),
        };

        // Currently, only List/Array columns (defined in `list_type_columns`)
        // are supported for unnesting.
        let is_unnested_column = unnest
            .list_type_columns
            .iter()
            .any(|entry| entry.1.output_column.name == col_ref.name);
        if !is_unnested_column {
            return Ok(Transformed::no(sub_expr));
        }

        // Resolve the expression that was projected into the unnested column.
        let source_expr = unnest
            .schema
            .index_of_column(col_ref)
            .ok()
            .and_then(|idx| match unnest.input.as_ref() {
                LogicalPlan::Projection(Projection {
                    expr: projected, ..
                }) => projected.get(idx),
                _ => None,
            });

        match source_expr {
            Some(unprojected_expr) => Ok(Transformed::yes(Expr::Unnest(
                expr::Unnest::new(unprojected_expr.clone()),
            ))),
            None => internal_err!(
                "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name
            ),
        }
    })
    .map(|transformed| transformed.data)
}

/// Recursively identify all Column expressions and transform them into the appropriate
/// aggregate expression contained in agg.
///
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
pub(crate) fn unproject_agg_exprs(
expr: &Expr,
expr: Expr,
agg: &Aggregate,
windows: Option<&[&Window]>,
) -> Result<Expr> {
expr.clone()
.transform(|sub_expr| {
expr.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
Ok(Transformed::yes(unprojected_expr.clone()))
} else if let Some(unprojected_expr) =
windows.and_then(|w| find_window_expr(w, &c.name).cloned())
{
// Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
return Ok(Transformed::yes(unproject_agg_exprs(&unprojected_expr, agg, None)?));
return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?));
} else {
internal_err!(
"Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
Expand All @@ -141,20 +191,19 @@ pub(crate) fn unproject_agg_exprs(
///
/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
/// into an actual window expression as identified in the window node.
pub(crate) fn unproject_window_exprs(expr: &Expr, windows: &[&Window]) -> Result<Expr> {
expr.clone()
.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unproj) = find_window_expr(windows, &c.name) {
Ok(Transformed::yes(unproj.clone()))
} else {
Ok(Transformed::no(Expr::Column(c)))
}
pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
expr.transform(|sub_expr| {
if let Expr::Column(c) = sub_expr {
if let Some(unproj) = find_window_expr(windows, &c.name) {
Ok(Transformed::yes(unproj.clone()))
} else {
Ok(Transformed::no(sub_expr))
Ok(Transformed::no(Expr::Column(c)))
}
})
.map(|e| e.data)
} else {
Ok(Transformed::no(sub_expr))
}
})
.map(|e| e.data)
}

fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
Expand Down Expand Up @@ -218,7 +267,7 @@ pub(crate) fn unproject_sort_expr(
// In case of aggregation there could be columns containing aggregation functions we need to unproject
if let Some(agg) = agg {
if agg.schema.is_column_from_schema(col_ref) {
let new_expr = unproject_agg_exprs(&sort_expr.expr, agg, None)?;
let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?;
sort_expr.expr = new_expr;
return Ok(sort_expr);
}
Expand Down
19 changes: 18 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_u
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
use datafusion_functions_nested::make_array::make_array_udf;
use datafusion_functions_window::rank::rank_udwf;
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
Expand Down Expand Up @@ -711,7 +712,8 @@ where
.with_aggregate_function(max_udaf())
.with_aggregate_function(grouping_udaf())
.with_window_function(rank_udwf())
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone())),
.with_scalar_function(Arc::new(unicode::substr().as_ref().clone()))
.with_scalar_function(make_array_udf()),
};
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
Expand Down Expand Up @@ -1084,3 +1086,18 @@ FROM person
GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(),
);
}

/// Round-trips queries containing `unnest(...)` through the planner and the
/// unparser and checks the generated SQL text.
#[test]
fn test_unnest_to_sql() {
    // Each case is (input query, expected unparsed SQL).
    let cases = [
        // Unnest of a table column, combined with a filter and an ORDER BY.
        (
            r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#,
            r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#,
        ),
        // Unnest of an array literal built with `make_array`, no FROM clause.
        (
            r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#,
            r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#,
        ),
    ];

    for (query, expected) in cases {
        sql_round_trip(GenericDialect {}, query, expected);
    }
}

0 comments on commit 1fd6116

Please sign in to comment.