Expression Simplifier doesn't consider associativity ((i + 1) + 2) is not simplified to i + 3) #11594

Open
alamb opened this issue Jul 22, 2024 · 12 comments
Labels
enhancement New feature or request

Comments

@alamb
Contributor

alamb commented Jul 22, 2024

Is your feature request related to a problem or challenge?

DataFusion will simplify expressions like this: i + (1 + 2) => i + 3

However, it will not simplify i + 1 + 2 (remains i + 1 + 2)

You can see this in the explain plans

> explain select column1 + (1 + 2) from values (100);
+---------------+-----------------------------------------------------------------------+
| plan_type     | plan                                                                  |
+---------------+-----------------------------------------------------------------------+
| logical_plan  | Projection: column1 + Int64(3) AS column1 + Int64(1) + Int64(2)       |
|               |   Values: (Int64(100))                                                |
| physical_plan | ProjectionExec: expr=[column1@0 + 3 as column1 + Int64(1) + Int64(2)] | <-- computed expression is `column1@0 + 3`
|               |   ValuesExec                                                          |
|               |                                                                       |
+---------------+-----------------------------------------------------------------------+
2 row(s) fetched.
Elapsed 0.013 seconds.

> explain select column1 + 1 + 2 from values (100);
+---------------+---------------------------------------------------------------------------+
| plan_type     | plan                                                                      |
+---------------+---------------------------------------------------------------------------+
| logical_plan  | Projection: column1 + Int64(1) + Int64(2)                                 |
|               |   Values: (Int64(100))                                                    |
| physical_plan | ProjectionExec: expr=[column1@0 + 1 + 2 as column1 + Int64(1) + Int64(2)] | <-- expression is STILL `column1@0 + 1 + 2`
|               |   ValuesExec                                                              |
|               |                                                                           |
+---------------+---------------------------------------------------------------------------+
2 row(s) fetched.
Elapsed 0.002 seconds.

@timsaucer has identified the problem

I don’t have time to look through the code right now, but I would guess the operations happen left to right when you don’t have parentheses to indicate order. So the second would be equivalent to (col("i") + lit(1)) + lit(2). That is, I would guess it isn’t checking for associativity of operations.

So in this case i + 1 + 2 is parsed as (i + 1) + 2 and since (i + 1) can't be reduced, the entire expression isn't either
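To make the diagnosis concrete, here is a minimal, self-contained sketch. It uses a toy `Expr` type with hypothetical `col`, `lit`, and `fold` helpers; none of this is DataFusion's real API, it only mimics the shape of the problem. A bottom-up folder that collapses an addition only when both children are literals handles `i + (1 + 2)` but leaves `i + 1 + 2` alone, because Rust's `+` is left-associative and builds `(i + 1) + 2`.

```rust
use std::ops::Add;

// Toy expression tree (an assumption for illustration, not DataFusion's `Expr`).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Col(&'static str),
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
}

// Operator overloading so that `a + b` builds a tree node, as in the Discord example.
impl Add for Expr {
    type Output = Expr;
    fn add(self, rhs: Expr) -> Expr {
        Expr::Add(Box::new(self), Box::new(rhs))
    }
}

fn col(name: &'static str) -> Expr {
    Expr::Col(name)
}

fn lit(value: i64) -> Expr {
    Expr::Lit(value)
}

// Bottom-up constant folding: an `Add` node is collapsed only when BOTH
// of its (already folded) children are literals and the sum does not overflow.
fn fold(expr: Expr) -> Expr {
    match expr {
        Expr::Add(l, r) => match (fold(*l), fold(*r)) {
            (Expr::Lit(a), Expr::Lit(b)) => match a.checked_add(b) {
                Some(sum) => Expr::Lit(sum),
                None => Expr::Lit(a) + Expr::Lit(b),
            },
            (l, r) => l + r,
        },
        other => other,
    }
}

fn main() {
    // Explicit parentheses put the two literals under the same node, so they fold.
    assert_eq!(fold(col("i") + (lit(1) + lit(2))), col("i") + lit(3));

    // Without parentheses the tree is (i + 1) + 2: no node has two literal
    // children, so nothing is folded.
    assert_eq!(fold(col("i") + lit(1) + lit(2)), (col("i") + lit(1)) + lit(2));
}
```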

Describe the solution you'd like

It would be nice to properly support this simplification

Describe alternatives you've considered

We'll have to consider how to apply associativity (I am sure there is prior art in this area). To solve the above issue, the simplifier would potentially need to reorder the operations so that constants are grouped together, and then re-apply the const evaluator.
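One possible way to "reorder so constants are together" is to flatten a chain of additions into a list of operands, sum the literals, and rebuild the chain. The sketch below shows that idea on a toy `Expr` type; `flatten` and `combine_literals` are hypothetical names, not part of DataFusion. It assumes integer addition, where regrouping is valid as long as no intermediate sum overflows; floating-point addition is not associative, so the same trick would not be safe there.

```rust
// Toy expression tree (an assumption for illustration, not DataFusion's `Expr`).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Col(String),
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
}

fn add(l: Expr, r: Expr) -> Expr {
    Expr::Add(Box::new(l), Box::new(r))
}

// Flatten nested additions into their operands, in left-to-right order.
fn flatten(expr: Expr, terms: &mut Vec<Expr>) {
    match expr {
        Expr::Add(l, r) => {
            flatten(*l, terms);
            flatten(*r, terms);
        }
        other => terms.push(other),
    }
}

// Group all literals of an addition chain, sum them, and rebuild the chain
// with the single combined literal at the end.
fn combine_literals(expr: Expr) -> Expr {
    let original = expr.clone();
    let mut terms = Vec::new();
    flatten(expr, &mut terms);

    let mut sum: i64 = 0;
    let mut lit_count = 0;
    let mut rest = Vec::new();
    for term in terms {
        match term {
            Expr::Lit(v) => match sum.checked_add(v) {
                Some(s) => {
                    sum = s;
                    lit_count += 1;
                }
                // Overflow while combining: be conservative and change nothing.
                None => return original,
            },
            other => rest.push(other),
        }
    }

    // Fewer than two literals means there is nothing to combine.
    if lit_count < 2 {
        return original;
    }

    let mut non_literals = rest.into_iter();
    match non_literals.next() {
        None => Expr::Lit(sum),
        Some(first) => add(non_literals.fold(first, add), Expr::Lit(sum)),
    }
}

fn main() {
    let i = || Expr::Col("i".to_string());
    // (i + 1) + 2 becomes i + 3
    let input = add(add(i(), Expr::Lit(1)), Expr::Lit(2));
    assert_eq!(combine_literals(input), add(i(), Expr::Lit(3)));
}
```

The flatten step is linear in the size of the chain, so this variant avoids any search over alternative groupings.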

impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
type Node = Expr;
/// rewrite the expression simplifying any constant expressions
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
use datafusion_expr::Operator::{
And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor,
Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
RegexNotIMatch, RegexNotMatch,
};
let info = self.info;
Ok(match expr {
//
// Rules for Eq
//
// true = A --> A
// false = A --> !A
// null = A --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: Eq,
right,
}) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
Transformed::yes(match as_bool_lit(*left)? {
Some(true) => *right,
Some(false) => Expr::Not(right),
None => lit_bool_null(),
})
}
// A = true --> A
// A = false --> !A
// A = null --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: Eq,
right,
}) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
Transformed::yes(match as_bool_lit(*right)? {
Some(true) => *left,
Some(false) => Expr::Not(left),
None => lit_bool_null(),
})
}
// Rules for NotEq
//
// true != A --> !A
// false != A --> A
// null != A --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: NotEq,
right,
}) if is_bool_lit(&left) && info.is_boolean_type(&right)? => {
Transformed::yes(match as_bool_lit(*left)? {
Some(true) => Expr::Not(right),
Some(false) => *right,
None => lit_bool_null(),
})
}
// A != true --> !A
// A != false --> A
// A != null --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: NotEq,
right,
}) if is_bool_lit(&right) && info.is_boolean_type(&left)? => {
Transformed::yes(match as_bool_lit(*right)? {
Some(true) => Expr::Not(left),
Some(false) => *left,
None => lit_bool_null(),
})
}
//
// Rules for OR
//
// true OR A --> true (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right: _,
}) if is_true(&left) => Transformed::yes(*left),
// false OR A --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_false(&left) => Transformed::yes(*right),
// A OR true --> true (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Or,
right,
}) if is_true(&right) => Transformed::yes(*right),
// A OR false --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_false(&right) => Transformed::yes(*left),
// A OR !A ---> true (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_not_of(&right, &left) && !info.nullable(&left)? => {
Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true))))
}
// !A OR A ---> true (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_not_of(&left, &right) && !info.nullable(&right)? => {
Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true))))
}
// (..A..) OR A --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
// A OR (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
// A OR (A AND B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if !info.nullable(&right)? && is_op_with(And, &right, &left) => {
Transformed::yes(*left)
}
// (A AND B) OR A --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if !info.nullable(&left)? && is_op_with(And, &left, &right) => {
Transformed::yes(*right)
}
//
// Rules for AND
//
// true AND A --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_true(&left) => Transformed::yes(*right),
// false AND A --> false (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right: _,
}) if is_false(&left) => Transformed::yes(*left),
// A AND true --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_true(&right) => Transformed::yes(*left),
// A AND false --> false (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left: _,
op: And,
right,
}) if is_false(&right) => Transformed::yes(*right),
// A AND !A ---> false (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_not_of(&right, &left) && !info.nullable(&left)? => {
Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false))))
}
// !A AND A ---> false (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_not_of(&left, &right) && !info.nullable(&right)? => {
Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false))))
}
// (..A..) AND A --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if expr_contains(&left, &right, And) => Transformed::yes(*left),
// A AND (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if expr_contains(&right, &left, And) => Transformed::yes(*right),
// A AND (A OR B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => {
Transformed::yes(*left)
}
// (A OR B) AND A --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => {
Transformed::yes(*right)
}
//
// Rules for Multiply
//
// A * 1 --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Multiply,
right,
}) if is_one(&right) => Transformed::yes(*left),
// 1 * A --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Multiply,
right,
}) if is_one(&left) => Transformed::yes(*right),
// A * null --> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Multiply,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null * A --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: Multiply,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
Expr::BinaryExpr(BinaryExpr {
left,
op: Multiply,
right,
}) if !info.nullable(&left)?
&& !info.get_data_type(&left)?.is_floating()
&& is_zero(&right) =>
{
Transformed::yes(*right)
}
// 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN)
Expr::BinaryExpr(BinaryExpr {
left,
op: Multiply,
right,
}) if !info.nullable(&right)?
&& !info.get_data_type(&right)?.is_floating()
&& is_zero(&left) =>
{
Transformed::yes(*left)
}
//
// Rules for Divide
//
// A / 1 --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Divide,
right,
}) if is_one(&right) => Transformed::yes(*left),
// null / A --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: Divide,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A / null --> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Divide,
right,
}) if is_null(&right) => Transformed::yes(*right),
//
// Rules for Modulo
//
// A % null --> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Modulo,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null % A --> null
Expr::BinaryExpr(BinaryExpr {
left,
op: Modulo,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
Expr::BinaryExpr(BinaryExpr {
left,
op: Modulo,
right,
}) if !info.nullable(&left)?
&& !info.get_data_type(&left)?.is_floating()
&& is_one(&right) =>
{
Transformed::yes(lit(0))
}
//
// Rules for BitwiseAnd
//
// A & null -> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: BitwiseAnd,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null & A -> null
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A & 0 -> 0 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right),
// 0 & A -> 0 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left),
// !A & A -> 0 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_zero(
&info.get_data_type(&left)?,
)?))
}
// A & !A -> 0 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_zero(
&info.get_data_type(&left)?,
)?))
}
// (..A..) & A --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
// A & (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
// A & (A | B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
Transformed::yes(*left)
}
// (A | B) & A --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
Transformed::yes(*right)
}
//
// Rules for BitwiseOr
//
// A | null -> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: BitwiseOr,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null | A -> null
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A | 0 -> A (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if is_zero(&right) => Transformed::yes(*left),
// 0 | A -> A (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if is_zero(&left) => Transformed::yes(*right),
// !A | A -> -1 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
&info.get_data_type(&left)?,
)?))
}
// A | !A -> -1 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
&info.get_data_type(&left)?,
)?))
}
// (..A..) | A --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
// A | (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
// A | (A & B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
Transformed::yes(*left)
}
// (A & B) | A --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
Transformed::yes(*right)
}
//
// Rules for BitwiseXor
//
// A ^ null -> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: BitwiseXor,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null ^ A -> null
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A ^ 0 -> A (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left),
// 0 ^ A -> A (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right),
// !A ^ A -> -1 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
&info.get_data_type(&left)?,
)?))
}
// A ^ !A -> -1 (if A not nullable)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
Transformed::yes(Expr::Literal(ScalarValue::new_negative_one(
&info.get_data_type(&left)?,
)?))
}
// (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if expr_contains(&left, &right, BitwiseXor) => {
let expr = delete_xor_in_complex_expr(&left, &right, false);
Transformed::yes(if expr == *right {
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
} else {
expr
})
}
// A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseXor,
right,
}) if expr_contains(&right, &left, BitwiseXor) => {
let expr = delete_xor_in_complex_expr(&right, &left, true);
Transformed::yes(if expr == *left {
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
} else {
expr
})
}
//
// Rules for BitwiseShiftRight
//
// A >> null -> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: BitwiseShiftRight,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null >> A -> null
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseShiftRight,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A >> 0 -> A (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseShiftRight,
right,
}) if is_zero(&right) => Transformed::yes(*left),
//
// Rules for BitwiseShiftLeft
//
// A << null -> null
Expr::BinaryExpr(BinaryExpr {
left: _,
op: BitwiseShiftLeft,
right,
}) if is_null(&right) => Transformed::yes(*right),
// null << A -> null
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseShiftLeft,
right: _,
}) if is_null(&left) => Transformed::yes(*left),
// A << 0 -> A (even if A is null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseShiftLeft,
right,
}) if is_zero(&right) => Transformed::yes(*left),
//
// Rules for Not
//
Expr::Not(inner) => Transformed::yes(negate_clause(*inner)),
//
// Rules for Negative
//
Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)),
//
// Rules for Case
//
// CASE
// WHEN X THEN A
// WHEN Y THEN B
// ...
// ELSE Q
// END
//
// ---> (X AND A) OR (Y AND B AND NOT X) OR ... (NOT (X OR Y) AND Q)
//
// Note: the rationale for this rewrite is that the expr can then be further
// simplified using the existing rules for AND/OR
Expr::Case(Case {
expr: None,
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);
for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
.and(filter_expr.clone().not())
.and(*then);
out_expr = out_expr.or(case_expr);
filter_expr = filter_expr.or(*when);
}
if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
// Do a first pass at simplification
out_expr.rewrite(self)?
}
Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
match udf.simplify(args, info)? {
ExprSimplifyResult::Original(args) => {
Transformed::no(Expr::ScalarFunction(ScalarFunction {
func: udf,
args,
}))
}
ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
}
}
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(ref udaf),
..
}) => match (udaf.simplify(), expr) {
(Some(simplify_function), Expr::AggregateFunction(af)) => {
Transformed::yes(simplify_function(af, info)?)
}
(_, expr) => Transformed::no(expr),
},
Expr::WindowFunction(WindowFunction {
fun: WindowFunctionDefinition::WindowUDF(ref udwf),
..
}) => match (udwf.simplify(), expr) {
(Some(simplify_function), Expr::WindowFunction(wf)) => {
Transformed::yes(simplify_function(wf, info)?)
}
(_, expr) => Transformed::no(expr),
},
//
// Rules for Between
//
// a between 3 and 5 --> a >= 3 AND a <=5
// a not between 3 and 5 --> a < 3 OR a > 5
Expr::Between(between) => Transformed::yes(if between.negated {
let l = *between.expr.clone();
let r = *between.expr;
or(l.lt(*between.low), r.gt(*between.high))
} else {
and(
between.expr.clone().gt_eq(*between.low),
between.expr.lt_eq(*between.high),
)
}),
//
// Rules for regexes
//
Expr::BinaryExpr(BinaryExpr {
left,
op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch),
right,
}) => Transformed::yes(simplify_regex_expr(left, op, right)?),
// Rules for Like
Expr::Like(Like {
expr,
pattern,
negated,
escape_char: _,
case_insensitive: _,
}) if !is_null(&expr)
&& matches!(
pattern.as_ref(),
Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%"
) =>
{
Transformed::yes(lit(!negated))
}
// a is not null/unknown --> true (if a is not nullable)
Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr)
if !info.nullable(&expr)? =>
{
Transformed::yes(lit(true))
}
// a is null/unknown --> false (if a is not nullable)
Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => {
Transformed::yes(lit(false))
}
// expr IN () --> false
// expr NOT IN () --> true
Expr::InList(InList {
expr,
list,
negated,
}) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => {
Transformed::yes(lit(negated))
}
// null in (x, y, z) --> null
// null not in (x, y, z) --> null
Expr::InList(InList {
expr,
list: _,
negated: _,
}) if is_null(expr.as_ref()) => Transformed::yes(lit_bool_null()),
// expr IN ((subquery)) -> expr IN (subquery), see #5529
Expr::InList(InList {
expr,
mut list,
negated,
}) if list.len() == 1
&& matches!(list.first(), Some(Expr::ScalarSubquery { .. })) =>
{
let Expr::ScalarSubquery(subquery) = list.remove(0) else {
unreachable!()
};
Transformed::yes(Expr::InSubquery(InSubquery::new(
expr, subquery, negated,
)))
}
// Combine multiple OR expressions into a single IN list expression if possible
//
// i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)`
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Or,
right,
}) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
let lhs = to_inlist(*left).unwrap();
let rhs = to_inlist(*right).unwrap();
let mut seen: HashSet<Expr> = HashSet::new();
let list = lhs
.list
.into_iter()
.chain(rhs.list)
.filter(|e| seen.insert(e.to_owned()))
.collect::<Vec<_>>();
let merged_inlist = InList {
expr: lhs.expr,
list,
negated: false,
};
Transformed::yes(Expr::InList(merged_inlist))
}

Additional context

Came up on discord ( @kavirajk) https://discord.com/channels/885562378132000778/1166447479609376850/1264325499971436594

@alamb alamb added the enhancement New feature or request label Jul 22, 2024
@alamb
Contributor Author

alamb commented Jul 22, 2024

Here is the full example from Discord:

use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::{error::Result, prelude::*};
use datafusion_common::ToDFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::simplify::SimplifyContext;

#[tokio::main]
async fn main() -> Result<()> {
    let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?;
    let props = ExecutionProps::new();
    let context = SimplifyContext::new(&props).with_schema(schema);
    let simplifier = ExprSimplifier::new(context);

    // i + (1 + 2) => i + 3
    assert_eq!(
        simplifier.simplify(col("i") + (lit(1) + lit(2)))?,
        col("i") + lit(3)
    );

    // i + 1 + 2 => i + 3 # not true. Why not?
    assert_eq!(
        simplifier.simplify(col("i") + lit(1) + lit(2))?,
        col("i") + lit(3)
    );

    Ok(())
}

fn make_field(name: &str, data_type: DataType) -> Field {
    let nullable = false;
    Field::new(name, data_type, nullable)
}

@timsaucer
Contributor

Off the cuff, if we go down this avenue I'd imagine wanting to check for associative, commutative, and distributive properties.

For example, (3*x + 8) + (5 + x*9) = 12*x + 13 takes 5 operations on the left-hand side but only 2 on the right.

(3*x + 1) * x + 7*x = 3*x^2 + 8*x takes 5 operations on the lhs and 4 on the rhs.

In these examples, x is any expression.

Associativity by itself may not pick up on 3 * x combining with x * 9, though maybe the simplifier already handles that. Again, I've only taken a cursory look at the simplifier so far.

I can imagine one approach of:

  1. distribute first
  2. commute all terms according to some ordering scheme - this should handle both multiplication and addition
  3. attempt pairwise simplification until no additional simplification can be done
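As a rough illustration of what those three steps achieve, the sketch below swaps in a simpler technique: normalizing to a polynomial coefficient vector. This is not the literal distribute/commute/simplify loop described above, and the `Expr` type and `normalize` function are hypothetical. Multiplying coefficient vectors performs the distribution, the fixed slot per power of `x` plays the role of the ordering scheme, and adding into the same slot combines like terms. It assumes a single symbol `x` and integer arithmetic without overflow checks.

```rust
// Toy expression over a single symbol `x` (an assumption for illustration).
#[derive(Debug, Clone)]
enum Expr {
    X,
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

fn lit(v: i64) -> Expr {
    Expr::Lit(v)
}

fn add(l: Expr, r: Expr) -> Expr {
    Expr::Add(Box::new(l), Box::new(r))
}

fn mul(l: Expr, r: Expr) -> Expr {
    Expr::Mul(Box::new(l), Box::new(r))
}

// Canonical form: element `k` is the coefficient of x^k, trailing zeros trimmed.
// The zero polynomial is the empty vector.
fn normalize(expr: &Expr) -> Vec<i64> {
    let mut coeffs = match expr {
        Expr::X => vec![0, 1],
        Expr::Lit(v) => vec![*v],
        Expr::Add(l, r) => {
            let (a, b) = (normalize(l), normalize(r));
            let mut out = vec![0; a.len().max(b.len())];
            // Like terms share an index, so adding by index combines them.
            for (k, c) in a.iter().enumerate() {
                out[k] += c;
            }
            for (k, c) in b.iter().enumerate() {
                out[k] += c;
            }
            out
        }
        Expr::Mul(l, r) => {
            let (a, b) = (normalize(l), normalize(r));
            if a.is_empty() || b.is_empty() {
                return vec![];
            }
            // Every term of `a` times every term of `b`: this is the distribution step.
            let mut out = vec![0; a.len() + b.len() - 1];
            for (i, ca) in a.iter().enumerate() {
                for (j, cb) in b.iter().enumerate() {
                    out[i + j] += ca * cb;
                }
            }
            out
        }
    };
    while coeffs.last() == Some(&0) {
        coeffs.pop();
    }
    coeffs
}

fn main() {
    // (3*x + 8) + (5 + x*9) normalizes to 13 + 12*x
    let e1 = add(
        add(mul(lit(3), Expr::X), lit(8)),
        add(lit(5), mul(Expr::X, lit(9))),
    );
    assert_eq!(normalize(&e1), vec![13, 12]);

    // (3*x + 1) * x + 7*x normalizes to 0 + 8*x + 3*x^2
    let e2 = add(
        mul(add(mul(lit(3), Expr::X), lit(1)), Expr::X),
        mul(lit(7), Expr::X),
    );
    assert_eq!(normalize(&e2), vec![0, 8, 3]);
}
```

In a real simplifier `x` would be an arbitrary sub-expression and the arithmetic would have to respect SQL types, nulls, and overflow, which is where most of the complexity would come from.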

@alamb
Contributor Author

alamb commented Jul 22, 2024

It would be interesting to do some research about what other systems do for this case (e.g. Calcite and DuckDB)

I suspect we could add some good special cases related to literals (e.g. apply the rules @timsaucer is describing above)

@timsaucer
Contributor

A couple of resources for using an existing CAS (Computer Algebra System).

SymPy is a Python package, so not as useful here, but it appears to be well built and has a good description of some of the problems of simplifying:

https://docs.sympy.org/latest/tutorials/intro-tutorial/simplification.html

SymEngine is implemented in C++ and appears to support most of the use cases I have thought about so far:

https://symengine.org/

Non-official crate to wrap SymEngine

https://crates.io/crates/symengine/0.1.0

Cas-rs is a newer CAS written in Rust that appears to be under active development:

https://github.com/ElectrifyPro/cas-rs

@alamb
Contributor Author

alamb commented Jul 23, 2024

Thank you @timsaucer -- very cool

I think there would be a tradeoff between the size of the dependency / how mature it is and additional benefit in DataFusion

So if we could implement something relatively simple (even if it duplicated existing functionality in another library) but didn't introduce a new dependency, I think that would be worth considering.

@tinfoil-knight
Contributor

DuckDB v1.0.0 (https://shell.duckdb.org/) behaves the same as Apache DataFusion: it doesn't consider associativity in expressions with a column reference.

duckdb> EXPLAIN SELECT col0 + (1+2) FROM (VALUES (100));
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│         (col0 + 3)        │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│      COLUMN_DATA_SCAN     │
└───────────────────────────┘

duckdb> EXPLAIN SELECT col0 + 1+2 FROM (VALUES (100));
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│      ((col0 + 1) + 2)     │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│      COLUMN_DATA_SCAN     │
└───────────────────────────┘ 

duckdb> EXPLAIN SELECT (col0 + 1)+2 FROM (VALUES (100));
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│      ((col0 + 1) + 2)     │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│      COLUMN_DATA_SCAN     │
└───────────────────────────┘

duckdb> EXPLAIN SELECT (col0 + (1+3)*7+9)+(2*5+3-4+6) FROM (VALUES (100));
┌───────────────────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│  (((col0 + 28) + 9) + 15) │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│      COLUMN_DATA_SCAN     │
└───────────────────────────┘

@tinfoil-knight
Contributor

tinfoil-knight commented Jul 30, 2024

I tried a simpler solution: re-arranging expressions in the Simplifier when they match the criteria for associativity (currently only the Plus and Multiply operators are considered).

We just need to re-arrange expressions like (i + 1) + 2 into i + (1 + 2). The rest of the expression will then be resolved by the existing ConstEvaluator.

We might have some duplicated code because of the various positions that we need to consider:
(i + 1) + 2 ; (1 + i) + 2 ; 2 + (i + 1) ; 2 + (1 + i) etc., but I'll try to reduce it as much as possible.

> explain select column1 + 1 + 2 from values (100);
+---------------+-----------------------------------------------------------------------+
| plan_type     | plan                                                                  |
+---------------+-----------------------------------------------------------------------+
| logical_plan  | Projection: column1 + Int64(3) AS column1 + Int64(1) + Int64(2)       |
|               |   Values: (Int64(100))                                                |
| physical_plan | ProjectionExec: expr=[column1@0 + 3 as column1 + Int64(1) + Int64(2)] |
|               |   ValuesExec                                                          |
|               |                                                                       |
+---------------+-----------------------------------------------------------------------+
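The four positions listed in this comment can be sketched as a single rewrite on a toy `Expr` type. This is only an illustration of the idea, not the code from the actual change: `reassociate` and `fold` are hypothetical names, `fold` stands in for the existing ConstEvaluator, and only integer addition is handled. Moving the non-literal operand to the left also relies on commutativity, not just associativity.

```rust
// Toy expression tree (an assumption for illustration, not DataFusion's `Expr`).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Col(String),
    Lit(i64),
    Add(Box<Expr>, Box<Expr>),
}

fn add(l: Expr, r: Expr) -> Expr {
    Expr::Add(Box::new(l), Box::new(r))
}

fn is_lit(e: &Expr) -> bool {
    matches!(e, Expr::Lit(_))
}

// Rewrite one `Add` node so that two literals become siblings.
fn reassociate(expr: Expr) -> Expr {
    match expr {
        Expr::Add(l, r) => match (*l, *r) {
            // Literal on the right of a nested addition.
            (Expr::Add(a, b), c2 @ Expr::Lit(_)) => match (*a, *b) {
                // (x + c1) + c2  -->  x + (c1 + c2)
                (x, c1 @ Expr::Lit(_)) if !is_lit(&x) => add(x, add(c1, c2)),
                // (c1 + x) + c2  -->  x + (c1 + c2)
                (c1 @ Expr::Lit(_), x) if !is_lit(&x) => add(x, add(c1, c2)),
                (a, b) => add(add(a, b), c2),
            },
            // Literal on the left of a nested addition.
            (c2 @ Expr::Lit(_), Expr::Add(a, b)) => match (*a, *b) {
                // c2 + (x + c1)  -->  x + (c2 + c1)
                (x, c1 @ Expr::Lit(_)) if !is_lit(&x) => add(x, add(c2, c1)),
                // c2 + (c1 + x)  -->  x + (c2 + c1)
                (c1 @ Expr::Lit(_), x) if !is_lit(&x) => add(x, add(c2, c1)),
                (a, b) => add(c2, add(a, b)),
            },
            (l, r) => add(l, r),
        },
        other => other,
    }
}

// Stand-in for the existing constant evaluator: folds `lit + lit` bottom-up,
// skipping the fold if the sum would overflow.
fn fold(expr: Expr) -> Expr {
    match expr {
        Expr::Add(l, r) => match (fold(*l), fold(*r)) {
            (Expr::Lit(a), Expr::Lit(b)) if a.checked_add(b).is_some() => Expr::Lit(a + b),
            (l, r) => add(l, r),
        },
        other => other,
    }
}

fn main() {
    let i = || Expr::Col("i".to_string());
    // (i + 1) + 2 is regrouped to i + (1 + 2), which the folder reduces to i + 3.
    let input = add(add(i(), Expr::Lit(1)), Expr::Lit(2));
    assert_eq!(fold(reassociate(input)), add(i(), Expr::Lit(3)));
}
```

Because the rule only inspects a node and its direct children, it adds a constant amount of work per node and no search over alternative groupings.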

@alamb
Contributor Author

alamb commented Jul 31, 2024

My biggest concern with any solution for this issue is keeping the code complexity reasonable and not slowing down planning (e.g., by adding some sort of exponential search algorithm).

@timsaucer
Contributor

Besides combining literal values, is there much value to this work? That is, are there other expressions that would reasonably get combined? I haven't looked much deeper at what the simplifier does.

I completely agree about the code complexity. This is a non-trivial problem that could become a deep rabbit hole.

@alamb
Contributor Author

alamb commented Jul 31, 2024

I think I agree: literal values are the major use case that I know of.

There are likely other simplifications that would benefit, such as some of the logical simplifications, but someone would have to think hard and find some examples.

I think adding some simple heuristics for when literals are involved might be the simplest way to make some improvements.

@drauschenbach
Contributor

drauschenbach commented Nov 4, 2024

Would dropping unnecessary parentheses in a first pass help?

@findepi
Member

findepi commented Nov 5, 2024

It might be that parens are a purely syntactic thing: they don't exist after we exit the parser.
