Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify generic expr rewrite functions into the datafusion_expr::expr_rewriter #6644

Merged
merged 1 commit into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ use datafusion_expr::expr::{
self, AggregateFunction, AggregateUDF, Between, BinaryExpr, Cast, GetIndexedField,
GroupingSet, InList, Like, ScalarUDF, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp};
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
use futures::future::BoxFuture;
Expand Down
121 changes: 119 additions & 2 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

//! Expression rewriter

use crate::expr::Sort;
use crate::logical_plan::Projection;
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use std::collections::HashMap;
Expand Down Expand Up @@ -216,13 +217,69 @@ fn coerce_exprs_for_schema(
.collect::<Result<_>>()
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// Rewrites `expr` using `rewriter`, ensuring that the output has the
/// same name as `expr` prior to rewrite, adding an alias if necessary.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
R: TreeNodeRewriter<N = Expr>,
{
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(rewriter)?;
add_alias_if_changed(original_name, expr)
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
// call Expr::display_name() on a Expr::Sort will throw an error
Expr::Sort(Sort { expr, .. }) => name_for_alias(expr),
expr => expr.display_name(),
}
}

/// Ensure `expr` has the name as `original_name` by adding an
/// alias if necessary.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first))
}
expr => expr.alias(original_name),
})
}

#[cfg(test)]
mod test {
use super::*;
use crate::{col, lit};
use crate::{col, lit, Cast};
use arrow::datatypes::DataType;
use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter};
use datafusion_common::{DFField, DFSchema, ScalarValue};
use std::ops::Add;

#[derive(Default)]
struct RecordingRewriter {
Expand Down Expand Up @@ -370,4 +427,64 @@ mod test {
]
)
}

#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));

test_rewrite(col("a"), col("b"));

// cast data types
test_rewrite(
col("a"),
Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
);

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// SortExpr a+1 ==> b + 2
test_rewrite(
Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)),
Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)),
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}

impl TreeNodeRewriter for TestRewriter {
type N = Expr;

fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}

let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = match &expr_from {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

let new_name = match &expr {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use std::time::Instant;
/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
/// the plan valid prior to the rest of the DataFusion optimization process.
///
/// For example, it may resolve [`Expr]s into more specific forms such
/// For example, it may resolve [`Expr`]s into more specific forms such
/// as a subquery reference, to do type coercion to ensure the types
/// of operands are correct.
///
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
ScalarUDF, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
Expand All @@ -47,7 +48,7 @@ use datafusion_expr::{
use datafusion_expr::{ExprSchemable, Signature};

use crate::analyzer::AnalyzerRule;
use crate::utils::{merge_schema, rewrite_preserving_name};
use crate::utils::merge_schema;

#[derive(Default)]
pub struct TypeCoercion {}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! of expr can be added if needed.
//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr.
use crate::optimizer::ApplyOrder;
use crate::utils::{merge_schema, rewrite_preserving_name};
use crate::utils::merge_schema;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
Expand All @@ -28,6 +28,7 @@ use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
Expand Down
118 changes: 1 addition & 117 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
//! Collection of utility functions that are leveraged by the query optimizer rules

use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter};
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{BinaryExpr, Sort};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference};
use datafusion_expr::logical_plan::LogicalPlanBuilder;
use datafusion_expr::utils::from_plan;
Expand Down Expand Up @@ -215,15 +214,6 @@ pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
filters.into_iter().reduce(|accum, expr| accum.or(expr))
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
/// its predicate be all `predicates` ANDed.
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
Expand Down Expand Up @@ -286,51 +276,6 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
}
}

/// Rewrites `expr` using `rewriter`, ensuring that the output has the
/// same name as `expr` prior to rewrite, adding an alias if necessary.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
R: TreeNodeRewriter<N = Expr>,
{
let original_name = name_for_alias(&expr)?;
let expr = expr.rewrite(rewriter)?;
add_alias_if_changed(original_name, expr)
}

/// Return the name to use for the specific Expr, recursing into
/// `Expr::Sort` as appropriate
fn name_for_alias(expr: &Expr) -> Result<String> {
match expr {
Expr::Sort(Sort { expr, .. }) => name_for_alias(expr),
expr => expr.display_name(),
}
}

/// Ensure `expr` has the name as `original_name` by adding an
/// alias if necessary.
fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
let new_name = name_for_alias(&expr)?;

if new_name == original_name {
return Ok(expr);
}

Ok(match expr {
Expr::Sort(Sort {
expr,
asc,
nulls_first,
}) => {
let expr = add_alias_if_changed(original_name, *expr)?;
Expr::Sort(Sort::new(Box::new(expr), asc, nulls_first))
}
expr => expr.alias(original_name),
})
}

/// merge inputs schema into a single schema.
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
if inputs.len() == 1 {
Expand Down Expand Up @@ -417,7 +362,6 @@ mod tests {
use datafusion_expr::expr::Cast;
use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;

#[test]
fn test_split_conjunction() {
Expand Down Expand Up @@ -558,64 +502,4 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}

#[test]
fn test_rewrite_preserving_name() {
test_rewrite(col("a"), col("a"));

test_rewrite(col("a"), col("b"));

// cast data types
test_rewrite(
col("a"),
Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
);

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// SortExpr a+1 ==> b + 2
test_rewrite(
Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)),
Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)),
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
}

impl TreeNodeRewriter for TestRewriter {
type N = Expr;

fn mutate(&mut self, _: Expr) -> Result<Expr> {
Ok(self.rewrite_to.clone())
}
}

let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = match &expr_from {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

let new_name = match &expr {
Expr::Sort(Sort { expr, .. }) => expr.display_name(),
expr => expr.display_name(),
}
.unwrap();

assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
}