diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 38e125dc9338..5fb77f75c1c5 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -67,11 +67,10 @@ use datafusion_expr::expr::{ self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, WindowFunction, }; -use datafusion_expr::expr_rewriter::unnormalize_cols; +use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp}; use datafusion_expr::{WindowFrame, WindowFrameBound}; -use datafusion_optimizer::utils::unalias; use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9d2144328d02..1e0d216cad31 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,9 +17,10 @@ //! 
Expression rewriter +use crate::expr::Sort; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use std::collections::HashMap; @@ -233,13 +234,69 @@ fn coerce_exprs_for_schema( .collect::>() } +/// Recursively un-alias an expression +#[inline] +pub fn unalias(expr: Expr) -> Expr { + match expr { + Expr::Alias(sub_expr, _) => unalias(*sub_expr), + _ => expr, + } +} + +/// Rewrites `expr` using `rewriter`, ensuring that the output has the +/// same name as `expr` prior to rewrite, adding an alias if necessary. +/// +/// This is important when optimizing plans to ensure the output +/// schema of plan nodes doesn't change after optimization +pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result +where + R: TreeNodeRewriter, +{ + let original_name = name_for_alias(&expr)?; + let expr = expr.rewrite(rewriter)?; + add_alias_if_changed(original_name, expr) +} + +/// Return the name to use for the specific Expr, recursing into +/// `Expr::Sort` as appropriate +fn name_for_alias(expr: &Expr) -> Result { + match expr { + // calling Expr::display_name() on an Expr::Sort returns an error + Expr::Sort(Sort { expr, .. }) => name_for_alias(expr), + expr => expr.display_name(), + } +} + +/// Ensure `expr` has the same name as `original_name` by adding an +/// alias if necessary. 
+fn add_alias_if_changed(original_name: String, expr: Expr) -> Result { + let new_name = name_for_alias(&expr)?; + + if new_name == original_name { + return Ok(expr); + } + + Ok(match expr { + Expr::Sort(Sort { + expr, + asc, + nulls_first, + }) => { + let expr = add_alias_if_changed(original_name, *expr)?; + Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first)) + } + expr => expr.alias(original_name), + }) +} + #[cfg(test)] mod test { use super::*; - use crate::{col, lit}; + use crate::{col, lit, Cast}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; + use std::ops::Add; #[derive(Default)] struct RecordingRewriter { @@ -387,4 +444,64 @@ mod test { ] ) } + + #[test] + fn test_rewrite_preserving_name() { + test_rewrite(col("a"), col("a")); + + test_rewrite(col("a"), col("b")); + + // cast data types + test_rewrite( + col("a"), + Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), + ); + + // change literal type from i32 to i64 + test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); + + // SortExpr a+1 ==> b + 2 + test_rewrite( + Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), + Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), + ); + } + + /// rewrites `expr_from` to `rewrite_to` using + /// `rewrite_preserving_name`, verifying the output name is unchanged + fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { + struct TestRewriter { + rewrite_to: Expr, + } + + impl TreeNodeRewriter for TestRewriter { + type N = Expr; + + fn mutate(&mut self, _: Expr) -> Result { + Ok(self.rewrite_to.clone()) + } + } + + let mut rewriter = TestRewriter { + rewrite_to: rewrite_to.clone(), + }; + let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); + + let original_name = match &expr_from { + Expr::Sort(Sort { expr, .. 
}) => expr.display_name(), + expr => expr.display_name(), + } + .unwrap(); + + let new_name = match &expr { + Expr::Sort(Sort { expr, .. }) => expr.display_name(), + expr => expr.display_name(), + } + .unwrap(); + + assert_eq!( + original_name, new_name, + "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" + ) + } } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 436bb3a06044..14d5ddf47378 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -40,7 +40,7 @@ use std::time::Instant; /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. /// -/// For example, it may resolve [`Expr]s into more specific forms such +/// For example, it may resolve [`Expr`]s into more specific forms such /// as a subquery reference, to do type coercion to ensure the types /// of operands are correct. /// diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3ee6a2401b02..e6023c469829 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,6 +28,7 @@ use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, ScalarUDF, WindowFunction, }; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ @@ -47,7 +48,7 @@ use datafusion_expr::{ use datafusion_expr::{ExprSchemable, Signature}; use crate::analyzer::AnalyzerRule; -use crate::utils::{merge_schema, rewrite_preserving_name}; +use crate::utils::merge_schema; #[derive(Default)] pub struct TypeCoercion {} diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 2c70ad0e9acd..daa695f77144 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,7 +19,7 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::{merge_schema, rewrite_preserving_name}; +use crate::utils::merge_schema; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -28,6 +28,7 @@ use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; +use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 32ef4e087923..50f753f8c972 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,10 +18,9 @@ //! 
Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::{BinaryExpr, Sort}; +use datafusion_expr::expr::BinaryExpr; use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ @@ -214,15 +213,6 @@ pub fn disjunction(filters: impl IntoIterator) -> Option { filters.into_iter().reduce(|accum, expr| accum.or(expr)) } -/// Recursively un-alias an expressions -#[inline] -pub fn unalias(expr: Expr) -> Expr { - match expr { - Expr::Alias(sub_expr, _) => unalias(*sub_expr), - _ => expr, - } -} - /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with /// its predicate be all `predicates` ANDed. pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { @@ -285,51 +275,6 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { } } -/// Rewrites `expr` using `rewriter`, ensuring that the output has the -/// same name as `expr` prior to rewrite, adding an alias if necessary. -/// -/// This is important when optimizing plans to ensure the output -/// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result -where - R: TreeNodeRewriter, -{ - let original_name = name_for_alias(&expr)?; - let expr = expr.rewrite(rewriter)?; - add_alias_if_changed(original_name, expr) -} - -/// Return the name to use for the specific Expr, recursing into -/// `Expr::Sort` as appropriate -fn name_for_alias(expr: &Expr) -> Result { - match expr { - Expr::Sort(Sort { expr, .. }) => name_for_alias(expr), - expr => expr.display_name(), - } -} - -/// Ensure `expr` has the name as `original_name` by adding an -/// alias if necessary. 
-fn add_alias_if_changed(original_name: String, expr: Expr) -> Result { - let new_name = name_for_alias(&expr)?; - - if new_name == original_name { - return Ok(expr); - } - - Ok(match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = add_alias_if_changed(original_name, *expr)?; - Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first)) - } - expr => expr.alias(original_name), - }) -} - /// merge inputs schema into a single schema. pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { if inputs.len() == 1 { @@ -393,7 +338,6 @@ mod tests { use datafusion_expr::expr::Cast; use datafusion_expr::{col, lit, utils::expr_to_columns}; use std::collections::HashSet; - use std::ops::Add; #[test] fn test_split_conjunction() { @@ -534,64 +478,4 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } - - #[test] - fn test_rewrite_preserving_name() { - test_rewrite(col("a"), col("a")); - - test_rewrite(col("a"), col("b")); - - // cast data types - test_rewrite( - col("a"), - Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), - ); - - // change literal type from i32 to i64 - test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); - - // SortExpr a+1 ==> b + 2 - test_rewrite( - Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), - Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), - ); - } - - /// rewrites `expr_from` to `rewrite_to` using - /// `rewrite_preserving_name` verifying the result is `expected_expr` - fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { - struct TestRewriter { - rewrite_to: Expr, - } - - impl TreeNodeRewriter for TestRewriter { - type N = Expr; - - fn mutate(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) - } - } - - let mut rewriter = TestRewriter { - rewrite_to: rewrite_to.clone(), - }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); - - let original_name = match &expr_from { - Expr::Sort(Sort { 
expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); - - let new_name = match &expr { - Expr::Sort(Sort { expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); - - assert_eq!( - original_name, new_name, - "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" - ) - } }