Skip to content

Commit

Permalink
refactor: unify generic expr rewrite functions into the `datafusion_e…
Browse files Browse the repository at this point in the history
…xpr::expr_rewriter` (#6644)
  • Loading branch information
r4ntix authored Jun 13, 2023
1 parent 40c1b9b commit ee80d06
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 124 deletions.
3 changes: 1 addition & 2 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ use datafusion_expr::expr::{
self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField,
GroupingSet, InList, Like, ScalarUDF, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp};
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
use futures::future::BoxFuture;
Expand Down
121 changes: 119 additions & 2 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

//! Expression rewriter

use crate::expr::Sort;
use crate::logical_plan::Projection;
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use std::collections::HashMap;
Expand Down Expand Up @@ -233,13 +234,69 @@ fn coerce_exprs_for_schema(
.collect::<Result<_>>()
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// Rewrites `expr` using `rewriter`, ensuring that the output has the
/// same name as `expr` prior to rewrite, adding an alias if necessary.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
R: TreeNodeRewriter<N = Expr>,
{
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(rewriter)?;
add_alias_if_changed(original_name, expr)
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
// call Expr::display_name() on a Expr::Sort will throw an error
Expr::Sort(Sort { expr, .. }) => name_for_alias(expr),
expr => expr.display_name(),
}
}

/// Ensure `expr` has the name as `original_name` by adding an
/// alias if necessary.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first))
}
expr => expr.alias(original_name),
})
}

#[cfg(test)]
mod test {
use super::*;
use crate::{col, lit};
use crate::{col, lit, Cast};
use arrow::datatypes::DataType;
use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter};
use datafusion_common::{DFField, DFSchema, ScalarValue};
use std::ops::Add;

#[derive(Default)]
struct RecordingRewriter {
Expand Down Expand Up @@ -387,4 +444,64 @@ mod test {
]
)
}

#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));

test_rewrite(col("a"), col("b"));

// cast data types
test_rewrite(
col("a"),
Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
);

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// SortExpr a+1 ==> b + 2
test_rewrite(
Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)),
Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)),
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}

impl TreeNodeRewriter for TestRewriter {
type N = Expr;

fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}

let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = match &expr_from {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

let new_name = match &expr {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use std::time::Instant;
/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
/// the plan valid prior to the rest of the DataFusion optimization process.
///
/// For example, it may resolve [`Expr]s into more specific forms such
/// For example, it may resolve [`Expr`]s into more specific forms such
/// as a subquery reference, to do type coercion to ensure the types
/// of operands are correct.
///
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
ScalarUDF, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
Expand All @@ -47,7 +48,7 @@ use datafusion_expr::{
use datafusion_expr::{ExprSchemable, Signature};

use crate::analyzer::AnalyzerRule;
use crate::utils::{merge_schema, rewrite_preserving_name};
use crate::utils::merge_schema;

#[derive(Default)]
pub struct TypeCoercion {}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! of expr can be added if needed.
//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr.
use crate::optimizer::ApplyOrder;
use crate::utils::{merge_schema, rewrite_preserving_name};
use crate::utils::merge_schema;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
Expand All @@ -28,6 +28,7 @@ use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
Expand Down
118 changes: 1 addition & 117 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
//! Collection of utility functions that are leveraged by the query optimizer rules

use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter};
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{BinaryExpr, Sort};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
Expand Down Expand Up @@ -214,15 +213,6 @@ pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(|accum, expr| accum.or(expr))
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
/// its predicate be all `predicates` ANDed.
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
Expand Down Expand Up @@ -285,51 +275,6 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
}
}

/// Rewrites `expr` using `rewriter`, ensuring that the output has the
/// same name as `expr` prior to rewrite, adding an alias if necessary.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
R: TreeNodeRewriter<N = Expr>,
{
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(rewriter)?;
add_alias_if_changed(original_name, expr)
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
Expr::Sort(Sort { expr, .. }) => name_for_alias(expr),
expr => expr.display_name(),
}
}

/// Ensure `expr` has the name as `original_name` by adding an
/// alias if necessary.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first))
}
expr => expr.alias(original_name),
})
}

/// merge inputs schema into a single schema.
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
if inputs.len() == 1 {
Expand Down Expand Up @@ -393,7 +338,6 @@ mod tests {
use datafusion_expr::expr::Cast;
use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;

#[test]
fn test_split_conjunction() {
Expand Down Expand Up @@ -534,64 +478,4 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));

test_rewrite(col("a"), col("b"));

// cast data types
test_rewrite(
col("a"),
Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
);

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// SortExpr a+1 ==> b + 2
test_rewrite(
Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)),
Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)),
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}

impl TreeNodeRewriter for TestRewriter {
type N = Expr;

fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}

let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = match &expr_from {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

let new_name = match &expr {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
}

0 comments on commit ee80d06

Please sign in to comment.