Skip to content

Commit 7f8d7cf

Browse files
committed
#17801 Improve nullability reporting of case expressions
1 parent 247450d commit 7f8d7cf

File tree

4 files changed

+449
-9
lines changed

4 files changed

+449
-9
lines changed

datafusion/core/tests/tpcds_planning.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
10541054
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
10561055
if create_physical {
10571056
let _ = state.create_physical_plan(&plan).await?;
1057+
} else {
1058+
let _ = state.optimize(&plan)?;
10581059
}
10591060
}
10601061

datafusion/expr/src/expr_fn.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ pub fn is_null(expr: Expr) -> Expr {
340340
Expr::IsNull(Box::new(expr))
341341
}
342342

343+
/// Create is not null expression
344+
pub fn is_not_null(expr: Expr) -> Expr {
345+
Expr::IsNotNull(Box::new(expr))
346+
}
347+
343348
/// Create is true expression
344349
pub fn is_true(expr: Expr) -> Expr {
345350
Expr::IsTrue(Box::new(expr))

datafusion/expr/src/expr_schema.rs

Lines changed: 198 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::{
3232
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
3333
Result, Spans, TableReference,
3434
};
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
3637
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3738
use std::sync::Arc;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284
let then_nullable = case
284285
.when_then_expr
285286
.iter()
287+
.filter(|(w, t)| {
288+
// Disregard branches where we can determine statically that the predicate
289+
// is always false when the then expression would evaluate to null
290+
const_result_when_value_is_null(w, t).unwrap_or(true)
291+
})
286292
.map(|(_, t)| t.nullable(input_schema))
287293
.collect::<Result<Vec<_>>>()?;
288294
if then_nullable.contains(&true) {
@@ -647,6 +653,50 @@ impl ExprSchemable for Expr {
647653
}
648654
}
649655

656+
/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+
/// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+
fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option<bool> {
659+
match predicate {
660+
Expr::IsNotNull(e) => {
661+
if e.as_ref().eq(value) {
662+
Some(false)
663+
} else {
664+
None
665+
}
666+
}
667+
Expr::IsNull(e) => {
668+
if e.as_ref().eq(value) {
669+
Some(true)
670+
} else {
671+
None
672+
}
673+
}
674+
Expr::Not(e) => const_result_when_value_is_null(e, value).map(|b| !b),
675+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
676+
Operator::And => {
677+
let l = const_result_when_value_is_null(left, value);
678+
let r = const_result_when_value_is_null(right, value);
679+
match (l, r) {
680+
(Some(l), Some(r)) => Some(l && r),
681+
(Some(l), None) => Some(l),
682+
(None, Some(r)) => Some(r),
683+
_ => None,
684+
}
685+
}
686+
Operator::Or => {
687+
let l = const_result_when_value_is_null(left, value);
688+
let r = const_result_when_value_is_null(right, value);
689+
match (l, r) {
690+
(Some(l), Some(r)) => Some(l || r),
691+
_ => None,
692+
}
693+
}
694+
_ => None,
695+
},
696+
_ => None,
697+
}
698+
}
699+
650700
impl Expr {
651701
/// Common method for window functions that applies type coercion
652702
/// to all arguments of the window function to check if it matches
@@ -777,7 +827,10 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777827
#[cfg(test)]
778828
mod tests {
779829
use super::*;
780-
use crate::{col, lit, out_ref_col_with_metadata};
830+
use crate::{
831+
and, binary_expr, col, is_not_null, is_null, lit, not, or,
832+
out_ref_col_with_metadata, when,
833+
};
781834

782835
use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue};
783836

@@ -830,6 +883,150 @@ mod tests {
830883
assert!(expr.nullable(&get_schema(false)).unwrap());
831884
}
832885

886+
fn check_nullability(
887+
expr: Expr,
888+
nullable: bool,
889+
get_schema: fn(bool) -> MockExprSchema,
890+
) -> Result<()> {
891+
assert_eq!(
892+
expr.nullable(&get_schema(true))?,
893+
nullable,
894+
"Nullability of '{expr}' should be {nullable} when column is nullable"
895+
);
896+
assert!(
897+
!expr.nullable(&get_schema(false))?,
898+
"Nullability of '{expr}' should be false when column is not nullable"
899+
);
900+
Ok(())
901+
}
902+
903+
#[test]
904+
fn test_case_expression_nullability() -> Result<()> {
905+
let get_schema = |nullable| {
906+
MockExprSchema::new()
907+
.with_data_type(DataType::Int32)
908+
.with_nullable(nullable)
909+
};
910+
911+
check_nullability(
912+
when(is_not_null(col("foo")), col("foo")).otherwise(lit(0))?,
913+
false,
914+
get_schema,
915+
)?;
916+
917+
check_nullability(
918+
when(not(is_null(col("foo"))), col("foo")).otherwise(lit(0))?,
919+
false,
920+
get_schema,
921+
)?;
922+
923+
check_nullability(
924+
when(binary_expr(col("foo"), Operator::Eq, lit(5)), col("foo"))
925+
.otherwise(lit(0))?,
926+
true,
927+
get_schema,
928+
)?;
929+
930+
check_nullability(
931+
when(
932+
and(
933+
is_not_null(col("foo")),
934+
binary_expr(col("foo"), Operator::Eq, lit(5)),
935+
),
936+
col("foo"),
937+
)
938+
.otherwise(lit(0))?,
939+
false,
940+
get_schema,
941+
)?;
942+
943+
check_nullability(
944+
when(
945+
and(
946+
binary_expr(col("foo"), Operator::Eq, lit(5)),
947+
is_not_null(col("foo")),
948+
),
949+
col("foo"),
950+
)
951+
.otherwise(lit(0))?,
952+
false,
953+
get_schema,
954+
)?;
955+
956+
check_nullability(
957+
when(
958+
or(
959+
is_not_null(col("foo")),
960+
binary_expr(col("foo"), Operator::Eq, lit(5)),
961+
),
962+
col("foo"),
963+
)
964+
.otherwise(lit(0))?,
965+
true,
966+
get_schema,
967+
)?;
968+
969+
check_nullability(
970+
when(
971+
or(
972+
binary_expr(col("foo"), Operator::Eq, lit(5)),
973+
is_not_null(col("foo")),
974+
),
975+
col("foo"),
976+
)
977+
.otherwise(lit(0))?,
978+
true,
979+
get_schema,
980+
)?;
981+
982+
check_nullability(
983+
when(
984+
or(
985+
is_not_null(col("foo")),
986+
binary_expr(col("foo"), Operator::Eq, lit(5)),
987+
),
988+
col("foo"),
989+
)
990+
.otherwise(lit(0))?,
991+
true,
992+
get_schema,
993+
)?;
994+
995+
check_nullability(
996+
when(
997+
or(
998+
binary_expr(col("foo"), Operator::Eq, lit(5)),
999+
is_not_null(col("foo")),
1000+
),
1001+
col("foo"),
1002+
)
1003+
.otherwise(lit(0))?,
1004+
true,
1005+
get_schema,
1006+
)?;
1007+
1008+
check_nullability(
1009+
when(
1010+
or(
1011+
and(
1012+
binary_expr(col("foo"), Operator::Eq, lit(5)),
1013+
is_not_null(col("foo")),
1014+
),
1015+
and(
1016+
binary_expr(col("foo"), Operator::Eq, col("bar")),
1017+
is_not_null(col("foo")),
1018+
),
1019+
),
1020+
col("foo"),
1021+
)
1022+
.otherwise(lit(0))?,
1023+
false,
1024+
get_schema,
1025+
)?;
1026+
1027+
Ok(())
1028+
}
1029+
8331030
#[test]
8341031
fn test_inlist_nullability() {
8351032
let get_schema = |nullable| {

0 commit comments

Comments
 (0)