Skip to content

Commit 0537da7

Browse files
committed
Support aliases in ConstEvaluator (apache#14734)
Not sure why they are not supported. It seems that if we're not careful, some transformations can introduce aliases nested inside other expressions.
1 parent 0405192 commit 0537da7

File tree

3 files changed

+48
-26
lines changed

3 files changed

+48
-26
lines changed

datafusion/core/tests/expr_api/simplification.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,33 @@ fn test_const_evaluator() {
365365
);
366366
}
367367

368+
#[test]
369+
fn test_const_evaluator_alias() {
370+
// true --> true
371+
test_evaluate(lit(true).alias("a"), lit(true));
372+
// true or true --> true
373+
test_evaluate(lit(true).alias("a").or(lit(true).alias("b")), lit(true));
374+
// "foo" == "foo" --> true
375+
test_evaluate(lit("foo").alias("a").eq(lit("foo").alias("b")), lit(true));
376+
// c = 1 + 2 --> c + 3
377+
test_evaluate(
378+
col("c")
379+
.alias("a")
380+
.eq(lit(1).alias("b") + lit(2).alias("c")),
381+
col("c").alias("a").eq(lit(3)),
382+
);
383+
// (foo != foo) OR (c = 1) --> false OR (c = 1)
384+
test_evaluate(
385+
lit("foo")
386+
.alias("a")
387+
.not_eq(lit("foo").alias("b"))
388+
.alias("c")
389+
.or(col("c").alias("d").eq(lit(1).alias("e")))
390+
.alias("f"),
391+
col("c").alias("d").eq(lit(1)).alias("f"),
392+
);
393+
}
394+
368395
#[test]
369396
fn test_const_evaluator_scalar_functions() {
370397
// concat("foo", "bar") --> "foobar"

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@ use datafusion_expr::{
4545
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
4646
use indexmap::IndexSet;
4747

48+
use super::inlist_simplifier::ShortenInListSimplifier;
49+
use super::utils::*;
4850
use crate::analyzer::type_coercion::TypeCoercionRewriter;
4951
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
5052
use crate::simplify_expressions::regex::simplify_regex_expr;
5153
use crate::simplify_expressions::SimplifyInfo;
5254

53-
use super::inlist_simplifier::ShortenInListSimplifier;
54-
use super::utils::*;
55-
5655
/// This structure handles API for expression simplification
5756
///
5857
/// Provides simplification information based on DFSchema and
@@ -514,30 +513,27 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
514513

515514
// NB: do not short circuit recursion even if we find a non
516515
// evaluatable node (so we can fold other children, args to
517-
// functions, etc)
516+
// functions, etc.)
518517
Ok(Transformed::no(expr))
519518
}
520519

521520
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
522521
match self.can_evaluate.pop() {
523-
// Certain expressions such as `CASE` and `COALESCE` are short circuiting
524-
// and may not evaluate all their sub expressions. Thus if
525-
// if any error is countered during simplification, return the original
522+
// Certain expressions such as `CASE` and `COALESCE` are short-circuiting
523+
// and may not evaluate all their sub expressions. Thus, if
524+
// any error is countered during simplification, return the original
526525
// so that normal evaluation can occur
527-
Some(true) => {
528-
let result = self.evaluate_to_scalar(expr);
529-
match result {
530-
ConstSimplifyResult::Simplified(s) => {
531-
Ok(Transformed::yes(Expr::Literal(s)))
532-
}
533-
ConstSimplifyResult::NotSimplified(s) => {
534-
Ok(Transformed::no(Expr::Literal(s)))
535-
}
536-
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
537-
Ok(Transformed::yes(expr))
538-
}
526+
Some(true) => match self.evaluate_to_scalar(expr) {
527+
ConstSimplifyResult::Simplified(s) => {
528+
Ok(Transformed::yes(Expr::Literal(s)))
539529
}
540-
}
530+
ConstSimplifyResult::NotSimplified(s) => {
531+
Ok(Transformed::no(Expr::Literal(s)))
532+
}
533+
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
534+
Ok(Transformed::yes(expr))
535+
}
536+
},
541537
Some(false) => Ok(Transformed::no(expr)),
542538
_ => internal_err!("Failed to pop can_evaluate"),
543539
}
@@ -585,9 +581,7 @@ impl<'a> ConstEvaluator<'a> {
585581
// added they can be checked for their ability to be evaluated
586582
// at plan time
587583
match expr {
588-
// Has no runtime cost, but needed during planning
589-
Expr::Alias(..)
590-
| Expr::AggregateFunction { .. }
584+
Expr::AggregateFunction { .. }
591585
| Expr::ScalarVariable(_, _)
592586
| Expr::Column(_)
593587
| Expr::OuterReferenceColumn(_, _)
@@ -602,6 +596,7 @@ impl<'a> ConstEvaluator<'a> {
602596
Self::volatility_ok(func.signature().volatility)
603597
}
604598
Expr::Literal(_)
599+
| Expr::Alias(..)
605600
| Expr::Unnest(_)
606601
| Expr::BinaryExpr { .. }
607602
| Expr::Not(_)

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ query TT
830830
explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1
831831
----
832832
logical_plan
833-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) AS _cnt ELSE __scalar_sq_1._cnt END AS cnt
833+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1._cnt END AS cnt
834834
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
835835
03)----TableScan: t1 projection=[t1_id, t1_int]
836836
04)----SubqueryAlias: __scalar_sq_1
@@ -851,7 +851,7 @@ query TT
851851
explain SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1
852852
----
853853
logical_plan
854-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS _cnt ELSE __scalar_sq_1._cnt END AS _cnt
854+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) ELSE __scalar_sq_1._cnt END AS _cnt
855855
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
856856
03)----TableScan: t1 projection=[t1_id, t1_int]
857857
04)----SubqueryAlias: __scalar_sq_1
@@ -918,7 +918,7 @@ query TT
918918
explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1
919919
----
920920
logical_plan
921-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
921+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
922922
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
923923
03)----TableScan: t1 projection=[t1_id, t1_int]
924924
04)----SubqueryAlias: __scalar_sq_1

0 commit comments

Comments
 (0)