Skip to content

Commit 4bbaa82

Browse files
committed
#17801 Improve nullability reporting of case expressions
1 parent 247450d commit 4bbaa82

File tree

4 files changed

+444
-9
lines changed

4 files changed

+444
-9
lines changed

datafusion/core/tests/tpcds_planning.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
10541054
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
10561055
if create_physical {
10571056
let _ = state.create_physical_plan(&plan).await?;
1057+
} else {
1058+
let _ = state.optimize(&plan)?;
10581059
}
10591060
}
10601061

datafusion/expr/src/expr_fn.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ pub fn is_null(expr: Expr) -> Expr {
340340
Expr::IsNull(Box::new(expr))
341341
}
342342

343+
/// Create is not null expression
344+
pub fn is_not_null(expr: Expr) -> Expr {
345+
Expr::IsNotNull(Box::new(expr))
346+
}
347+
343348
/// Create is true expression
344349
pub fn is_true(expr: Expr) -> Expr {
345350
Expr::IsTrue(Box::new(expr))

datafusion/expr/src/expr_schema.rs

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::{
3232
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
3333
Result, Spans, TableReference,
3434
};
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
3637
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3738
use std::sync::Arc;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284
let then_nullable = case
284285
.when_then_expr
285286
.iter()
287+
.filter(|(w, t)| {
288+
// Disregard branches where we can determine statically that the predicate
289+
// is always false when the then expression would evaluate to null
290+
const_result_when_value_is_null(w, t).unwrap_or(true)
291+
})
286292
.map(|(_, t)| t.nullable(input_schema))
287293
.collect::<Result<Vec<_>>>()?;
288294
if then_nullable.contains(&true) {
@@ -647,6 +653,50 @@ impl ExprSchemable for Expr {
647653
}
648654
}
649655

656+
/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+
/// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+
fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option<bool> {
659+
match predicate {
660+
Expr::IsNotNull(e) => {
661+
if e.as_ref().eq(value) {
662+
Some(false)
663+
} else {
664+
None
665+
}
666+
}
667+
Expr::IsNull(e) => {
668+
if e.as_ref().eq(value) {
669+
Some(true)
670+
} else {
671+
None
672+
}
673+
}
674+
Expr::Not(e) => const_result_when_value_is_null(e, value).map(|b| !b),
675+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
676+
Operator::And => {
677+
let l = const_result_when_value_is_null(left, value);
678+
let r = const_result_when_value_is_null(right, value);
679+
match (l, r) {
680+
(Some(l), Some(r)) => Some(l && r),
681+
(Some(l), None) => Some(l),
682+
(None, Some(r)) => Some(r),
683+
_ => None,
684+
}
685+
}
686+
Operator::Or => {
687+
let l = const_result_when_value_is_null(left, value);
688+
let r = const_result_when_value_is_null(right, value);
689+
match (l, r) {
690+
(Some(l), Some(r)) => Some(l || r),
691+
_ => None,
692+
}
693+
}
694+
_ => None,
695+
},
696+
_ => None,
697+
}
698+
}
699+
650700
impl Expr {
651701
/// Common method for window functions that applies type coercion
652702
/// to all arguments of the window function to check if it matches
@@ -777,7 +827,10 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777827
#[cfg(test)]
778828
mod tests {
779829
use super::*;
780-
use crate::{col, lit, out_ref_col_with_metadata};
830+
use crate::{
831+
and, binary_expr, col, is_not_null, is_null, lit, not, or,
832+
out_ref_col_with_metadata, when,
833+
};
781834

782835
use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue};
783836

@@ -830,6 +883,153 @@ mod tests {
830883
assert!(expr.nullable(&get_schema(false)).unwrap());
831884
}
832885

886+
fn check_nullability(
887+
expr: Expr,
888+
nullable: bool,
889+
get_schema: fn(bool) -> MockExprSchema,
890+
) -> Result<()> {
891+
assert_eq!(
892+
expr.nullable(&get_schema(true))?,
893+
nullable,
894+
"Nullability of '{}' should be {} when column is nullable",
895+
expr,
896+
nullable
897+
);
898+
assert!(
899+
!expr.nullable(&get_schema(false))?,
900+
"Nullability of '{}' should be false when column is not nullable",
901+
expr
902+
);
903+
Ok(())
904+
}
905+
906+
/// Verifies the nullability reported for `CASE WHEN <predicate> THEN foo ELSE 0 END`
/// for a range of predicates over the column `foo`.
///
/// Each table entry pairs a predicate with the nullability expected when the schema
/// reports columns as nullable; `check_nullability` additionally asserts that the
/// expression is non-nullable when the schema reports columns as non-nullable.
#[test]
fn test_case_expression_nullability() -> Result<()> {
    let get_schema = |nullable| {
        MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_nullable(nullable)
    };

    // Small builders so every case gets its own freshly constructed expression tree.
    let foo = || col("foo");
    let foo_eq_5 = || binary_expr(col("foo"), Operator::Eq, lit(5));
    let foo_eq_bar = || binary_expr(col("foo"), Operator::Eq, col("bar"));

    let cases = vec![
        // Predicates that directly rule out a NULL `foo`
        (is_not_null(foo()), false),
        (not(is_null(foo())), false),
        // Predicates that select a NULL `foo`, so the branch stays nullable
        (is_null(foo()), true),
        (not(is_not_null(foo())), true),
        // A plain comparison proves nothing about NULL-ness
        (foo_eq_5(), true),
        // A conjunction is false as soon as either side is, in any operand order
        (and(is_not_null(foo()), foo_eq_5()), false),
        (and(foo_eq_5(), is_not_null(foo())), false),
        // A disjunction with one undetermined side cannot be decided
        (or(is_not_null(foo()), foo_eq_5()), true),
        (or(foo_eq_5(), is_not_null(foo())), true),
        // A disjunction is false only when both sides are false
        (
            or(
                and(foo_eq_5(), is_not_null(foo())),
                and(foo_eq_bar(), is_not_null(foo())),
            ),
            false,
        ),
    ];

    for (predicate, nullable) in cases {
        check_nullability(
            when(predicate, foo()).otherwise(lit(0))?,
            nullable,
            get_schema,
        )?;
    }

    Ok(())
}
1032+
8331033
#[test]
8341034
fn test_inlist_nullability() {
8351035
let get_schema = |nullable| {

0 commit comments

Comments
 (0)