Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ pub trait Dialect: Send + Sync {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Date
}

/// Whether the dialect supports specifying column aliases as part of a table alias
/// definition, for example:
/// `(SELECT col1, col2 FROM my_table) AS my_table_alias(col1_alias, col2_alias)`
///
/// The default implementation returns `true`. When a dialect returns `false`, the plan
/// unparser injects the aliases into the inner projection instead
/// (`SELECT col1 AS col1_alias, ...`) and emits the table alias without a column list.
fn supports_column_alias_in_table_alias(&self) -> bool {
true
}
}

/// `IntervalStyle` to use for unparsing
Expand Down Expand Up @@ -221,6 +227,10 @@ impl Dialect for SqliteDialect {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::Text
}

// This dialect reports column alias lists on a table alias (`AS t(a, b)`) as unsupported,
// overriding the trait default of `true`.
fn supports_column_alias_in_table_alias(&self) -> bool {
false
}
}

pub struct CustomDialect {
Expand All @@ -236,6 +246,7 @@ pub struct CustomDialect {
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: sqlparser::ast::DataType,
supports_column_alias_in_table_alias: bool,
}

impl Default for CustomDialect {
Expand All @@ -256,6 +267,7 @@ impl Default for CustomDialect {
TimezoneInfo::WithTimeZone,
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
}
}
}
Expand Down Expand Up @@ -323,6 +335,10 @@ impl Dialect for CustomDialect {
fn date32_cast_dtype(&self) -> sqlparser::ast::DataType {
self.date32_cast_dtype.clone()
}

// Returns the flag stored on this `CustomDialect`; it defaults to `true` and can be
// changed through `CustomDialectBuilder`.
fn supports_column_alias_in_table_alias(&self) -> bool {
self.supports_column_alias_in_table_alias
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down Expand Up @@ -352,6 +368,7 @@ pub struct CustomDialectBuilder {
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: ast::DataType,
supports_column_alias_in_table_alias: bool,
}

impl Default for CustomDialectBuilder {
Expand All @@ -378,6 +395,7 @@ impl CustomDialectBuilder {
TimezoneInfo::WithTimeZone,
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
}
}

Expand All @@ -395,6 +413,8 @@ impl CustomDialectBuilder {
timestamp_cast_dtype: self.timestamp_cast_dtype,
timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
date32_cast_dtype: self.date32_cast_dtype,
supports_column_alias_in_table_alias: self
.supports_column_alias_in_table_alias,
}
}

Expand Down Expand Up @@ -482,4 +502,13 @@ impl CustomDialectBuilder {
self.date32_cast_dtype = date32_cast_dtype;
self
}

/// Customize the dialect to supports column aliases as part of alias table definition
pub fn with_supports_column_alias_in_table_alias(
mut self,
supports_column_alias_in_table_alias: bool,
) -> Self {
self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias;
self
}
}
36 changes: 31 additions & 5 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::{internal_err, not_impl_err, Column, DataFusionError, Result};
use datafusion_common::{
internal_err, not_impl_err, plan_err, Column, DataFusionError, Result,
};
use datafusion_expr::{
expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection,
SortExpr,
Expand All @@ -30,7 +32,8 @@ use super::{
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
rewrite::{
normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields,
inject_column_aliases, normalize_union_schema,
rewrite_plan_for_sort_on_non_projected_fields,
subquery_alias_inner_query_and_columns,
},
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Expand Down Expand Up @@ -450,10 +453,33 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::SubqueryAlias(plan_alias) => {
// Handle bottom-up to allocate relation
let (plan, columns) = subquery_alias_inner_query_and_columns(plan_alias);
let (plan, mut columns) =
subquery_alias_inner_query_and_columns(plan_alias);

if !columns.is_empty()
&& !self.dialect.supports_column_alias_in_table_alias()
{
// if columns are returned then the plan corresponds to a projection
let LogicalPlan::Projection(inner_p) = plan else {
return plan_err!(
"Inner projection for subquery alias is expected"
);
};

// Instead of specifying column aliases as part of the outer table, inject them directly into the inner projection
let rewritten_plan = inject_column_aliases(inner_p, columns);
columns = vec![];

self.select_to_sql_recursively(
&rewritten_plan,
query,
select,
relation,
)?;
} else {
self.select_to_sql_recursively(plan, query, select, relation)?;
}

self.select_to_sql_recursively(plan, query, select, relation)?;
relation.alias(Some(
self.new_table_alias(plan_alias.alias.table().to_string(), columns),
));
Expand Down
36 changes: 35 additions & 1 deletion datafusion/sql/src/unparser/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode},
Result,
};
use datafusion_expr::tree_node::transform_sort_vec;
use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
use sqlparser::ast::Ident;

Expand Down Expand Up @@ -257,6 +257,40 @@ pub(super) fn subquery_alias_inner_query_and_columns(
(outer_projections.input.as_ref(), columns)
}

/// Injects column aliases into the projection of a logical plan by wrapping each projected
/// expression with `Expr::Alias`, pairing expressions and aliases by position.
///
/// Every expression that has a matching alias is wrapped, not only `Expr::Column`: this
/// function is used when the dialect cannot express `AS table_alias(col_alias, ...)`, so
/// the inner projection is the only place left to carry the names. Leaving a computed
/// expression (e.g. a cast or arithmetic) unaliased would make outer references to the
/// alias unresolvable.
///
/// For `Expr::Column` inputs the column's relation is copied onto the alias; other
/// expressions get no relation. If fewer aliases than expressions are supplied, the
/// remaining expressions are kept unchanged rather than dropped; surplus aliases are
/// ignored.
///
/// Only the expression list of the cloned projection is replaced; its other fields are
/// carried over from `projection` as-is.
///
/// Example:
/// - `SELECT col1, col2 + 1 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to
/// - `SELECT col1 AS alias_1, col2 + 1 AS some_alias_2 FROM table`
pub(super) fn inject_column_aliases(
    projection: &datafusion_expr::Projection,
    aliases: impl IntoIterator<Item = Ident>,
) -> LogicalPlan {
    let mut updated_projection = projection.clone();
    let mut aliases = aliases.into_iter();

    let new_exprs = updated_projection
        .expr
        .into_iter()
        .map(|expr| {
            // No alias left for this position: keep the expression instead of
            // silently removing it from the projection.
            let Some(col_alias) = aliases.next() else {
                return expr;
            };

            // Preserve the qualifier of plain columns; computed expressions have none.
            let relation = match &expr {
                Expr::Column(col) => col.relation.clone(),
                _ => None,
            };

            Expr::Alias(Alias {
                expr: Box::new(expr),
                relation,
                name: col_alias.value,
            })
        })
        .collect::<Vec<_>>();

    updated_projection.expr = new_exprs;

    LogicalPlan::Projection(updated_projection)
}

fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
match logical_plan {
LogicalPlan::Projection(p) => Some(p),
Expand Down
16 changes: 14 additions & 2 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_expr::{col, table_scan};
use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_sql::unparser::dialect::{
DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect,
MySqlDialect as UnparserMySqlDialect,
MySqlDialect as UnparserMySqlDialect, SqliteDialect,
};
use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser};

Expand Down Expand Up @@ -406,7 +406,19 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
}
},
TestStatementWithDialect {
sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)",
expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)",
expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(SqliteDialect {}),
},
];

for query in tests {
Expand Down