Skip to content

Commit ff9b921

Browse files
committed
Simplify comparisons and binary operations involving NULL
Existing optimizations already simplified some arithmetic operations (`*`, `/`, `%` and some bitwise operations) when one operand is a constant `NULL`. This commit extends that simplification to almost all other binary operators.
1 parent f0630fb commit ff9b921

File tree

10 files changed

+144
-191
lines changed

10 files changed

+144
-191
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 80 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,39 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
773773

774774
let info = self.info;
775775
Ok(match expr {
776+
// `value op NULL` -> `NULL`
777+
// `NULL op value` -> `NULL`
778+
// except for a few operators that can return a non-null value even when one of the operands is NULL
779+
ref expr @ Expr::BinaryExpr(BinaryExpr {
780+
ref left,
781+
ref op,
782+
ref right,
783+
}) if binary_op_null_on_null(*op)
784+
&& (is_null(left.as_ref()) || is_null(right.as_ref())) =>
785+
{
786+
Transformed::yes(Expr::Literal(
787+
ScalarValue::try_new_null(&info.get_data_type(expr)?)?,
788+
None,
789+
))
790+
}
791+
792+
// `value OR NULL` -> `value`
793+
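// `NULL OR value` -> `value` (NOTE(review): `FALSE OR NULL` is `NULL` in SQL, so this rewrite turns NULL into FALSE when `value` is false; confirm this is intended)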
794+
Expr::BinaryExpr(BinaryExpr {
795+
left,
796+
op: Or,
797+
right,
798+
}) if is_null(&left) || is_null(&right) => {
799+
let left_is_null = is_null(&left);
800+
if left_is_null && is_null(&right) {
801+
Transformed::yes(lit_bool_null())
802+
} else if left_is_null {
803+
Transformed::yes(*right)
804+
} else {
805+
Transformed::yes(*left)
806+
}
807+
}
808+
776809
//
777810
// Rules for Eq
778811
//
@@ -1048,14 +1081,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
10481081
}) if is_one(&right) => {
10491082
simplify_right_is_one_case(info, left, &Multiply, &right)?
10501083
}
1051-
// A * null --> null
1052-
Expr::BinaryExpr(BinaryExpr {
1053-
left,
1054-
op: Multiply,
1055-
right,
1056-
}) if is_null(&right) => {
1057-
simplify_right_is_null_case(info, &left, &Multiply, right)?
1058-
}
10591084
// 1 * A --> A
10601085
Expr::BinaryExpr(BinaryExpr {
10611086
left,
@@ -1065,14 +1090,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
10651090
// 1 * A is equivalent to A * 1
10661091
simplify_right_is_one_case(info, right, &Multiply, &left)?
10671092
}
1068-
// null * A --> null
1069-
Expr::BinaryExpr(BinaryExpr {
1070-
left,
1071-
op: Multiply,
1072-
right,
1073-
}) if is_null(&left) => {
1074-
simplify_right_is_null_case(info, &right, &Multiply, left)?
1075-
}
10761093

10771094
// A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
10781095
Expr::BinaryExpr(BinaryExpr {
@@ -1109,37 +1126,11 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
11091126
}) if is_one(&right) => {
11101127
simplify_right_is_one_case(info, left, &Divide, &right)?
11111128
}
1112-
// A / null --> null
1113-
Expr::BinaryExpr(BinaryExpr {
1114-
left,
1115-
op: Divide,
1116-
right,
1117-
}) if is_null(&right) => {
1118-
simplify_right_is_null_case(info, &left, &Divide, right)?
1119-
}
1120-
// null / A --> null
1121-
Expr::BinaryExpr(BinaryExpr {
1122-
left,
1123-
op: Divide,
1124-
right,
1125-
}) if is_null(&left) => simplify_null_div_other_case(info, left, &right)?,
11261129

11271130
//
11281131
// Rules for Modulo
11291132
//
11301133

1131-
// A % null --> null
1132-
Expr::BinaryExpr(BinaryExpr {
1133-
left: _,
1134-
op: Modulo,
1135-
right,
1136-
}) if is_null(&right) => Transformed::yes(*right),
1137-
// null % A --> null
1138-
Expr::BinaryExpr(BinaryExpr {
1139-
left,
1140-
op: Modulo,
1141-
right: _,
1142-
}) if is_null(&left) => Transformed::yes(*left),
11431134
// A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN)
11441135
Expr::BinaryExpr(BinaryExpr {
11451136
left,
@@ -1159,20 +1150,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
11591150
// Rules for BitwiseAnd
11601151
//
11611152

1162-
// A & null -> null
1163-
Expr::BinaryExpr(BinaryExpr {
1164-
left: _,
1165-
op: BitwiseAnd,
1166-
right,
1167-
}) if is_null(&right) => Transformed::yes(*right),
1168-
1169-
// null & A -> null
1170-
Expr::BinaryExpr(BinaryExpr {
1171-
left,
1172-
op: BitwiseAnd,
1173-
right: _,
1174-
}) if is_null(&left) => Transformed::yes(*left),
1175-
11761153
// A & 0 -> 0 (if A not nullable)
11771154
Expr::BinaryExpr(BinaryExpr {
11781155
left,
@@ -1247,20 +1224,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
12471224
// Rules for BitwiseOr
12481225
//
12491226

1250-
// A | null -> null
1251-
Expr::BinaryExpr(BinaryExpr {
1252-
left: _,
1253-
op: BitwiseOr,
1254-
right,
1255-
}) if is_null(&right) => Transformed::yes(*right),
1256-
1257-
// null | A -> null
1258-
Expr::BinaryExpr(BinaryExpr {
1259-
left,
1260-
op: BitwiseOr,
1261-
right: _,
1262-
}) if is_null(&left) => Transformed::yes(*left),
1263-
12641227
// A | 0 -> A (even if A is null)
12651228
Expr::BinaryExpr(BinaryExpr {
12661229
left,
@@ -1335,20 +1298,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
13351298
// Rules for BitwiseXor
13361299
//
13371300

1338-
// A ^ null -> null
1339-
Expr::BinaryExpr(BinaryExpr {
1340-
left: _,
1341-
op: BitwiseXor,
1342-
right,
1343-
}) if is_null(&right) => Transformed::yes(*right),
1344-
1345-
// null ^ A -> null
1346-
Expr::BinaryExpr(BinaryExpr {
1347-
left,
1348-
op: BitwiseXor,
1349-
right: _,
1350-
}) if is_null(&left) => Transformed::yes(*left),
1351-
13521301
// A ^ 0 -> A (if A not nullable)
13531302
Expr::BinaryExpr(BinaryExpr {
13541303
left,
@@ -1425,20 +1374,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14251374
// Rules for BitwiseShiftRight
14261375
//
14271376

1428-
// A >> null -> null
1429-
Expr::BinaryExpr(BinaryExpr {
1430-
left: _,
1431-
op: BitwiseShiftRight,
1432-
right,
1433-
}) if is_null(&right) => Transformed::yes(*right),
1434-
1435-
// null >> A -> null
1436-
Expr::BinaryExpr(BinaryExpr {
1437-
left,
1438-
op: BitwiseShiftRight,
1439-
right: _,
1440-
}) if is_null(&left) => Transformed::yes(*left),
1441-
14421377
// A >> 0 -> A (even if A is null)
14431378
Expr::BinaryExpr(BinaryExpr {
14441379
left,
@@ -1450,20 +1385,6 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
14501385
// Rules for BitwiseShiftLeft
14511386
//
14521387

1453-
// A << null -> null
1454-
Expr::BinaryExpr(BinaryExpr {
1455-
left: _,
1456-
op: BitwiseShiftLeft,
1457-
right,
1458-
}) if is_null(&right) => Transformed::yes(*right),
1459-
1460-
// null << A -> null
1461-
Expr::BinaryExpr(BinaryExpr {
1462-
left,
1463-
op: BitwiseShiftLeft,
1464-
right: _,
1465-
}) if is_null(&left) => Transformed::yes(*left),
1466-
14671388
// A << 0 -> A (even if A is null)
14681389
Expr::BinaryExpr(BinaryExpr {
14691390
left,
@@ -1947,6 +1868,53 @@ fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
19471868
iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
19481869
}
19491870

1871+
fn binary_op_null_on_null(op: Operator) -> bool {
1872+
match op {
1873+
Operator::Eq
1874+
| Operator::NotEq
1875+
| Operator::Lt
1876+
| Operator::LtEq
1877+
| Operator::Gt
1878+
| Operator::GtEq
1879+
| Operator::Plus
1880+
| Operator::Minus
1881+
| Operator::Multiply
1882+
| Operator::Divide
1883+
| Operator::Modulo
1884+
| Operator::And
1885+
| Operator::RegexMatch
1886+
| Operator::RegexIMatch
1887+
| Operator::RegexNotMatch
1888+
| Operator::RegexNotIMatch
1889+
| Operator::LikeMatch
1890+
| Operator::ILikeMatch
1891+
| Operator::NotLikeMatch
1892+
| Operator::NotILikeMatch
1893+
| Operator::BitwiseAnd
1894+
| Operator::BitwiseOr
1895+
| Operator::BitwiseXor
1896+
| Operator::BitwiseShiftRight
1897+
| Operator::BitwiseShiftLeft
1898+
| Operator::AtArrow
1899+
| Operator::ArrowAt
1900+
| Operator::Arrow
1901+
| Operator::LongArrow
1902+
| Operator::HashArrow
1903+
| Operator::HashLongArrow
1904+
| Operator::AtAt
1905+
| Operator::IntegerDivide
1906+
| Operator::HashMinus
1907+
| Operator::AtQuestion
1908+
| Operator::Question
1909+
| Operator::QuestionAnd
1910+
| Operator::QuestionPipe => true,
1911+
Operator::Or
1912+
| Operator::IsDistinctFrom
1913+
| Operator::IsNotDistinctFrom
1914+
| Operator::StringConcat => false,
1915+
}
1916+
}
1917+
19501918
// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
19511919
fn are_inlist_and_eq_and_match_neg(
19521920
left: &Expr,
@@ -2111,58 +2079,6 @@ fn simplify_right_is_one_case<S: SimplifyInfo>(
21112079
}
21122080
}
21132081

2114-
// A * null -> null
2115-
// A / null -> null
2116-
//
2117-
// Move this function body out of the large match branch avoid stack overflow
2118-
fn simplify_right_is_null_case<S: SimplifyInfo>(
2119-
info: &S,
2120-
left: &Expr,
2121-
op: &Operator,
2122-
right: Box<Expr>,
2123-
) -> Result<Transformed<Expr>> {
2124-
// Check if resulting type would be different due to coercion
2125-
let left_type = info.get_data_type(left)?;
2126-
let right_type = info.get_data_type(&right)?;
2127-
match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2128-
Ok(result_type) => {
2129-
// Only cast if the types differ
2130-
if right_type != result_type {
2131-
Ok(Transformed::yes(Expr::Cast(Cast::new(right, result_type))))
2132-
} else {
2133-
Ok(Transformed::yes(*right))
2134-
}
2135-
}
2136-
Err(_) => Ok(Transformed::yes(*right)),
2137-
}
2138-
}
2139-
2140-
// null / A --> null
2141-
//
2142-
// Move this function body out of the large match branch avoid stack overflow
2143-
fn simplify_null_div_other_case<S: SimplifyInfo>(
2144-
info: &S,
2145-
left: Box<Expr>,
2146-
right: &Expr,
2147-
) -> Result<Transformed<Expr>> {
2148-
// Check if resulting type would be different due to coercion
2149-
let left_type = info.get_data_type(&left)?;
2150-
let right_type = info.get_data_type(right)?;
2151-
match BinaryTypeCoercer::new(&left_type, &Operator::Divide, &right_type)
2152-
.get_result_type()
2153-
{
2154-
Ok(result_type) => {
2155-
// Only cast if the types differ
2156-
if left_type != result_type {
2157-
Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2158-
} else {
2159-
Ok(Transformed::yes(*left))
2160-
}
2161-
}
2162-
Err(_) => Ok(Transformed::yes(*left)),
2163-
}
2164-
}
2165-
21662082
#[cfg(test)]
21672083
mod tests {
21682084
use super::*;

datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ mod tests {
297297
let expected = col("c2").eq(lit(16i64));
298298
assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
299299

300-
// cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL)
300+
// cast(c1, INT64) < INT64(NULL) => NULL
301301
let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64());
302-
let expected = col("c1").lt(null_i32());
302+
let expected = null_bool();
303303
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
304304

305305
// cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL)
@@ -317,9 +317,9 @@ mod tests {
317317
let expected = col("c1").not_eq(lit(123i32));
318318
assert_eq!(optimize_test(expr_input, &schema), expected);
319319

320-
// cast(c1, UTF8) = NULL => c1 = NULL
320+
// cast(c1, UTF8) = NULL => NULL
321321
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
322-
let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
322+
let expected = null_bool();
323323
assert_eq!(optimize_test(expr_input, &schema), expected);
324324
}
325325

@@ -422,7 +422,7 @@ mod tests {
422422

423423
// c3 < INT64(NULL)
424424
let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64());
425-
let expected = col("c3").lt(null_decimal(18, 2));
425+
let expected = null_bool();
426426
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
427427

428428
// decimal to decimal
@@ -653,10 +653,6 @@ mod tests {
653653
lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
654654
}
655655

656-
fn null_decimal(precision: u8, scale: i8) -> Expr {
657-
lit(ScalarValue::Decimal128(None, precision, scale))
658-
}
659-
660656
fn timestamp_nano_none_type() -> DataType {
661657
DataType::Timestamp(TimeUnit::Nanosecond, None)
662658
}

datafusion/sqllogictest/test_files/aggregate_skip_partial.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ a true false true
189189
b true false true
190190
c true false false
191191
d true false false
192-
e true false NULL
192+
e true false false
193193

194194
query TBBB rowsort
195195
select v1,
@@ -404,7 +404,7 @@ a true false true
404404
b true false true
405405
c true false false
406406
d true false false
407-
e true false NULL
407+
e true false false
408408

409409
query TBBB rowsort
410410
select v1,

0 commit comments

Comments
 (0)