Skip to content

Commit 03173f0

Browse files
committed
Simplify comparisons and binary operations involving NULL
There were optimizations simplifying some arithmetic operations (`*`, `/`, `%` and some bitwise operations) when one operand is constant `NULL`. This can be extended to almost all other binary operators.
1 parent f0630fb commit 03173f0

File tree

10 files changed

+142
-191
lines changed

10 files changed

+142
-191
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 78 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,37 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
773773

774774
let info = self.info;
775775
Ok(match expr {
776+
// `value op NULL` -> `NULL`
777+
// `NULL op value` -> `NULL`
778+
// except for few operators that can return non-null value even when one of the operands is NULL
779+
ref expr @ Expr::BinaryExpr(BinaryExpr {
780+
ref left,
781+
ref op,
782+
ref right,
783+
}) if binary_op_null_on_null(*op) && (is_null(left.as_ref()) || is_null(right.as_ref())) => {
784+
Transformed::yes(Expr::Literal(
785+
ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
786+
None,
787+
))
788+
}
789+
790+
// `value OR NULL` -> `value`
791+
// `NULL OR value` -> `value`
792+
Expr::BinaryExpr(BinaryExpr {
793+
left,
794+
op: Or,
795+
right,
796+
}) if is_null(&left) || is_null(&right) => {
797+
let left_is_null = is_null(&left);
798+
if left_is_null && is_null(&right) {
799+
Transformed::yes(lit_bool_null())
800+
} else if left_is_null {
801+
Transformed::yes(*right)
802+
} else {
803+
Transformed::yes(*left)
804+
}
805+
}
806+
776807
//
777808
// Rules for Eq
778809
//
@@ -1048,14 +1079,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
10481079
}) if is_one(&right) => {
10491080
simplify_right_is_one_case(info, left, &Multiply, &right)?
10501081
}
1051-
// A * null --> null
1052-
Expr::BinaryExpr(BinaryExpr {
1053-
left,
1054-
op: Multiply,
1055-
right,
1056-
}) if is_null(&right) => {
1057-
simplify_right_is_null_case(info, &left, &Multiply, right)?
1058-
}
10591082
// 1 * A --> A
10601083
Expr::BinaryExpr(BinaryExpr {
10611084
left,
@@ -1065,14 +1088,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
10651088
// 1 * A is equivalent to A * 1
10661089
simplify_right_is_one_case(info, right, &Multiply, &left)?
10671090
}
1068-
// null * A --> null
1069-
Expr::BinaryExpr(BinaryExpr {
1070-
left,
1071-
op: Multiply,
1072-
right,
1073-
}) if is_null(&left) => {
1074-
simplify_right_is_null_case(info, &right, &Multiply, left)?
1075-
}
10761091

10771092
// A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
10781093
Expr::BinaryExpr(BinaryExpr {
@@ -1109,37 +1124,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
11091124
}) if is_one(&right) => {
11101125
simplify_right_is_one_case(info, left, &Divide, &right)?
11111126
}
1112-
// A / null --> null
1113-
Expr::BinaryExpr(BinaryExpr {
1114-
left,
1115-
op: Divide,
1116-
right,
1117-
}) if is_null(&right) => {
1118-
simplify_right_is_null_case(info, &left, &Divide, right)?
1119-
}
1120-
// null / A --> null
1121-
Expr::BinaryExpr(BinaryExpr {
1122-
left,
1123-
op: Divide,
1124-
right,
1125-
}) if is_null(&left) => simplify_null_div_other_case(info, left, &right)?,
11261127

11271128
//
11281129
// Rules for Modulo
11291130
//
11301131

1131-
// A % null --> null
1132-
Expr::BinaryExpr(BinaryExpr {
1133-
left: _,
1134-
op: Modulo,
1135-
right,
1136-
}) if is_null(&right) => Transformed::yes(*right),
1137-
// null % A --> null
1138-
Expr::BinaryExpr(BinaryExpr {
1139-
left,
1140-
op: Modulo,
1141-
right: _,
1142-
}) if is_null(&left) => Transformed::yes(*left),
11431132
// A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
11441133
Expr::BinaryExpr(BinaryExpr {
11451134
left,
@@ -1159,20 +1148,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
11591148
// Rules for BitwiseAnd
11601149
//
11611150

1162-
// A & null -> null
1163-
Expr::BinaryExpr(BinaryExpr {
1164-
left: _,
1165-
op: BitwiseAnd,
1166-
right,
1167-
}) if is_null(&right) => Transformed::yes(*right),
1168-
1169-
// null & A -> null
1170-
Expr::BinaryExpr(BinaryExpr {
1171-
left,
1172-
op: BitwiseAnd,
1173-
right: _,
1174-
}) if is_null(&left) => Transformed::yes(*left),
1175-
11761151
// A & 0 -> 0 (if A not nullable)
11771152
Expr::BinaryExpr(BinaryExpr {
11781153
left,
@@ -1247,20 +1222,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
12471222
// Rules for BitwiseOr
12481223
//
12491224

1250-
// A | null -> null
1251-
Expr::BinaryExpr(BinaryExpr {
1252-
left: _,
1253-
op: BitwiseOr,
1254-
right,
1255-
}) if is_null(&right) => Transformed::yes(*right),
1256-
1257-
// null | A -> null
1258-
Expr::BinaryExpr(BinaryExpr {
1259-
left,
1260-
op: BitwiseOr,
1261-
right: _,
1262-
}) if is_null(&left) => Transformed::yes(*left),
1263-
12641225
// A | 0 -> A (even if A is null)
12651226
Expr::BinaryExpr(BinaryExpr {
12661227
left,
@@ -1335,20 +1296,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
13351296
// Rules for BitwiseXor
13361297
//
13371298

1338-
// A ^ null -> null
1339-
Expr::BinaryExpr(BinaryExpr {
1340-
left: _,
1341-
op: BitwiseXor,
1342-
right,
1343-
}) if is_null(&right) => Transformed::yes(*right),
1344-
1345-
// null ^ A -> null
1346-
Expr::BinaryExpr(BinaryExpr {
1347-
left,
1348-
op: BitwiseXor,
1349-
right: _,
1350-
}) if is_null(&left) => Transformed::yes(*left),
1351-
13521299
// A ^ 0 -> A (if A not nullable)
13531300
Expr::BinaryExpr(BinaryExpr {
13541301
left,
@@ -1425,20 +1372,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14251372
// Rules for BitwiseShiftRight
14261373
//
14271374

1428-
// A >> null -> null
1429-
Expr::BinaryExpr(BinaryExpr {
1430-
left: _,
1431-
op: BitwiseShiftRight,
1432-
right,
1433-
}) if is_null(&right) => Transformed::yes(*right),
1434-
1435-
// null >> A -> null
1436-
Expr::BinaryExpr(BinaryExpr {
1437-
left,
1438-
op: BitwiseShiftRight,
1439-
right: _,
1440-
}) if is_null(&left) => Transformed::yes(*left),
1441-
14421375
// A >> 0 -> A (even if A is null)
14431376
Expr::BinaryExpr(BinaryExpr {
14441377
left,
@@ -1450,20 +1383,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14501383
// Rules for BitwiseShiftRight
14511384
//
14521385

1453-
// A << null -> null
1454-
Expr::BinaryExpr(BinaryExpr {
1455-
left: _,
1456-
op: BitwiseShiftLeft,
1457-
right,
1458-
}) if is_null(&right) => Transformed::yes(*right),
1459-
1460-
// null << A -> null
1461-
Expr::BinaryExpr(BinaryExpr {
1462-
left,
1463-
op: BitwiseShiftLeft,
1464-
right: _,
1465-
}) if is_null(&left) => Transformed::yes(*left),
1466-
14671386
// A << 0 -> A (even if A is null)
14681387
Expr::BinaryExpr(BinaryExpr {
14691388
left,
@@ -1947,6 +1866,53 @@ fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
19471866
iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
19481867
}
19491868

1869+
fn binary_op_null_on_null(op: Operator) -> bool {
1870+
match op {
1871+
Operator::Eq
1872+
| Operator::NotEq
1873+
| Operator::Lt
1874+
| Operator::LtEq
1875+
| Operator::Gt
1876+
| Operator::GtEq
1877+
| Operator::Plus
1878+
| Operator::Minus
1879+
| Operator::Multiply
1880+
| Operator::Divide
1881+
| Operator::Modulo
1882+
| Operator::And
1883+
| Operator::RegexMatch
1884+
| Operator::RegexIMatch
1885+
| Operator::RegexNotMatch
1886+
| Operator::RegexNotIMatch
1887+
| Operator::LikeMatch
1888+
| Operator::ILikeMatch
1889+
| Operator::NotLikeMatch
1890+
| Operator::NotILikeMatch
1891+
| Operator::BitwiseAnd
1892+
| Operator::BitwiseOr
1893+
| Operator::BitwiseXor
1894+
| Operator::BitwiseShiftRight
1895+
| Operator::BitwiseShiftLeft
1896+
| Operator::AtArrow
1897+
| Operator::ArrowAt
1898+
| Operator::Arrow
1899+
| Operator::LongArrow
1900+
| Operator::HashArrow
1901+
| Operator::HashLongArrow
1902+
| Operator::AtAt
1903+
| Operator::IntegerDivide
1904+
| Operator::HashMinus
1905+
| Operator::AtQuestion
1906+
| Operator::Question
1907+
| Operator::QuestionAnd
1908+
| Operator::QuestionPipe => true,
1909+
Operator::Or
1910+
| Operator::IsDistinctFrom
1911+
| Operator::IsNotDistinctFrom
1912+
| Operator::StringConcat => false,
1913+
}
1914+
}
1915+
19501916
// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
19511917
fn are_inlist_and_eq_and_match_neg(
19521918
left: &Expr,
@@ -2111,58 +2077,6 @@ fn simplify_right_is_one_case<S: SimplifyInfo>(
21112077
}
21122078
}
21132079

2114-
// A * null -> null
2115-
// A / null -> null
2116-
//
2117-
// Move this function body out of the large match branch avoid stack overflow
2118-
fn simplify_right_is_null_case<S: SimplifyInfo>(
2119-
info: &S,
2120-
left: &Expr,
2121-
op: &Operator,
2122-
right: Box<Expr>,
2123-
) -> Result<Transformed<Expr>> {
2124-
// Check if resulting type would be different due to coercion
2125-
let left_type = info.get_data_type(left)?;
2126-
let right_type = info.get_data_type(&right)?;
2127-
match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2128-
Ok(result_type) => {
2129-
// Only cast if the types differ
2130-
if right_type != result_type {
2131-
Ok(Transformed::yes(Expr::Cast(Cast::new(right, result_type))))
2132-
} else {
2133-
Ok(Transformed::yes(*right))
2134-
}
2135-
}
2136-
Err(_) => Ok(Transformed::yes(*right)),
2137-
}
2138-
}
2139-
2140-
// null / A --> null
2141-
//
2142-
// Move this function body out of the large match branch avoid stack overflow
2143-
fn simplify_null_div_other_case<S: SimplifyInfo>(
2144-
info: &S,
2145-
left: Box<Expr>,
2146-
right: &Expr,
2147-
) -> Result<Transformed<Expr>> {
2148-
// Check if resulting type would be different due to coercion
2149-
let left_type = info.get_data_type(&left)?;
2150-
let right_type = info.get_data_type(right)?;
2151-
match BinaryTypeCoercer::new(&left_type, &Operator::Divide, &right_type)
2152-
.get_result_type()
2153-
{
2154-
Ok(result_type) => {
2155-
// Only cast if the types differ
2156-
if left_type != result_type {
2157-
Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2158-
} else {
2159-
Ok(Transformed::yes(*left))
2160-
}
2161-
}
2162-
Err(_) => Ok(Transformed::yes(*left)),
2163-
}
2164-
}
2165-
21662080
#[cfg(test)]
21672081
mod tests {
21682082
use super::*;

datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ mod tests {
297297
let expected = col("c2").eq(lit(16i64));
298298
assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
299299

300-
// cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL)
300+
// cast(c1, INT64) < INT64(NULL) => NULL
301301
let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64());
302-
let expected = col("c1").lt(null_i32());
302+
let expected = null_bool();
303303
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
304304

305305
// cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL)
@@ -317,9 +317,9 @@ mod tests {
317317
let expected = col("c1").not_eq(lit(123i32));
318318
assert_eq!(optimize_test(expr_input, &schema), expected);
319319

320-
// cast(c1, UTF8) = NULL => c1 = NULL
320+
// cast(c1, UTF8) = NULL => NULL
321321
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
322-
let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
322+
let expected = null_bool();
323323
assert_eq!(optimize_test(expr_input, &schema), expected);
324324
}
325325

@@ -422,7 +422,7 @@ mod tests {
422422

423423
// c3 < INT64(NULL)
424424
let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64());
425-
let expected = col("c3").lt(null_decimal(18, 2));
425+
let expected = null_bool();
426426
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
427427

428428
// decimal to decimal
@@ -653,10 +653,6 @@ mod tests {
653653
lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
654654
}
655655

656-
fn null_decimal(precision: u8, scale: i8) -> Expr {
657-
lit(ScalarValue::Decimal128(None, precision, scale))
658-
}
659-
660656
fn timestamp_nano_none_type() -> DataType {
661657
DataType::Timestamp(TimeUnit::Nanosecond, None)
662658
}

datafusion/sqllogictest/test_files/aggregate_skip_partial.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ a true false true
189189
b true false true
190190
c true false false
191191
d true false false
192-
e true false NULL
192+
e true false false
193193

194194
query TBBB rowsort
195195
select v1,
@@ -404,7 +404,7 @@ a true false true
404404
b true false true
405405
c true false false
406406
d true false false
407-
e true false NULL
407+
e true false false
408408

409409
query TBBB rowsort
410410
select v1,

0 commit comments

Comments
 (0)