Skip to content

Commit 482d0be

Browse files
committed
#17801 Improve nullability reporting of case expressions
1 parent 5bbdb7e commit 482d0be

File tree

4 files changed

+427
-9
lines changed

4 files changed

+427
-9
lines changed

datafusion/core/tests/tpcds_planning.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
10541054
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
10561055
if create_physical {
10571056
let _ = state.create_physical_plan(&plan).await?;
1057+
} else {
1058+
let _ = state.optimize(&plan)?;
10581059
}
10591060
}
10601061

datafusion/expr/src/expr_fn.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ pub fn is_null(expr: Expr) -> Expr {
340340
Expr::IsNull(Box::new(expr))
341341
}
342342

343+
/// Create is not null expression
344+
pub fn is_not_null(expr: Expr) -> Expr {
345+
Expr::IsNotNull(Box::new(expr))
346+
}
347+
343348
/// Create is true expression
344349
pub fn is_true(expr: Expr) -> Expr {
345350
Expr::IsTrue(Box::new(expr))

datafusion/expr/src/expr_schema.rs

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::{
3232
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
3333
Result, Spans, TableReference,
3434
};
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
3637
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3738
use std::sync::Arc;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284
let then_nullable = case
284285
.when_then_expr
285286
.iter()
287+
.filter(|(w, t)| {
288+
// Disregard branches where we can determine statically that the predicate
289+
// is always false when the then expression would evaluate to null
290+
const_result_when_value_is_null(w, t).unwrap_or(true)
291+
})
286292
.map(|(_, t)| t.nullable(input_schema))
287293
.collect::<Result<Vec<_>>>()?;
288294
if then_nullable.contains(&true) {
@@ -647,6 +653,38 @@ impl ExprSchemable for Expr {
647653
}
648654
}
649655

656+
/// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+
/// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+
fn const_result_when_value_is_null(predicate: &Expr, value: &Expr) -> Option<bool> {
659+
match predicate {
660+
Expr::IsNotNull(e) => if e.as_ref().eq(value) { Some(false) } else { None },
661+
Expr::IsNull(e) => if e.as_ref().eq(value) { Some(true) } else { None },
662+
Expr::Not(e) => const_result_when_value_is_null(e, value).map(|b| !b),
663+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
664+
Operator::And => {
665+
let l = const_result_when_value_is_null(left, value);
666+
let r = const_result_when_value_is_null(right, value);
667+
match (l, r) {
668+
(Some(l), Some(r)) => Some(l && r),
669+
(Some(l), None) => Some(l),
670+
(None, Some(r)) => Some(r),
671+
_ => None,
672+
}
673+
}
674+
Operator::Or => {
675+
let l = const_result_when_value_is_null(left, value);
676+
let r = const_result_when_value_is_null(right, value);
677+
match (l, r) {
678+
(Some(l), Some(r)) => Some(l || r),
679+
_ => None,
680+
}
681+
}
682+
_ => None,
683+
},
684+
_ => None,
685+
}
686+
}
687+
650688
impl Expr {
651689
/// Common method for window functions that applies type coercion
652690
/// to all arguments of the window function to check if it matches
@@ -777,7 +815,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777815
#[cfg(test)]
778816
mod tests {
779817
use super::*;
780-
use crate::{col, lit, out_ref_col_with_metadata};
818+
use crate::{and, binary_expr, col, is_not_null, is_null, lit, not, or, out_ref_col_with_metadata, when};
781819

782820
use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue};
783821

@@ -830,6 +868,153 @@ mod tests {
830868
assert!(expr.nullable(&get_schema(false)).unwrap());
831869
}
832870

871+
fn check_nullability(
872+
expr: Expr,
873+
nullable: bool,
874+
get_schema: fn(bool) -> MockExprSchema,
875+
) -> Result<()> {
876+
assert_eq!(
877+
expr.nullable(&get_schema(true))?,
878+
nullable,
879+
"Nullability of '{}' should be {} when column is nullable",
880+
expr,
881+
nullable
882+
);
883+
assert!(
884+
!expr.nullable(&get_schema(false))?,
885+
"Nullability of '{}' should be false when column is not nullable",
886+
expr
887+
);
888+
Ok(())
889+
}
890+
891+
#[test]
892+
fn test_case_expression_nullability() -> Result<()> {
893+
let get_schema = |nullable| {
894+
MockExprSchema::new()
895+
.with_data_type(DataType::Int32)
896+
.with_nullable(nullable)
897+
};
898+
899+
check_nullability(
900+
when(is_not_null(col("foo")), col("foo")).otherwise(lit(0))?,
901+
false,
902+
get_schema,
903+
)?;
904+
905+
check_nullability(
906+
when(not(is_null(col("foo"))), col("foo")).otherwise(lit(0))?,
907+
false,
908+
get_schema,
909+
)?;
910+
911+
check_nullability(
912+
when(binary_expr(col("foo"), Operator::Eq, lit(5)), col("foo"))
913+
.otherwise(lit(0))?,
914+
true,
915+
get_schema,
916+
)?;
917+
918+
check_nullability(
919+
when(
920+
and(
921+
is_not_null(col("foo")),
922+
binary_expr(col("foo"), Operator::Eq, lit(5)),
923+
),
924+
col("foo"),
925+
)
926+
.otherwise(lit(0))?,
927+
false,
928+
get_schema,
929+
)?;
930+
931+
check_nullability(
932+
when(
933+
and(
934+
binary_expr(col("foo"), Operator::Eq, lit(5)),
935+
is_not_null(col("foo")),
936+
),
937+
col("foo"),
938+
)
939+
.otherwise(lit(0))?,
940+
false,
941+
get_schema,
942+
)?;
943+
944+
check_nullability(
945+
when(
946+
or(
947+
is_not_null(col("foo")),
948+
binary_expr(col("foo"), Operator::Eq, lit(5)),
949+
),
950+
col("foo"),
951+
)
952+
.otherwise(lit(0))?,
953+
true,
954+
get_schema,
955+
)?;
956+
957+
check_nullability(
958+
when(
959+
or(
960+
binary_expr(col("foo"), Operator::Eq, lit(5)),
961+
is_not_null(col("foo")),
962+
),
963+
col("foo"),
964+
)
965+
.otherwise(lit(0))?,
966+
true,
967+
get_schema,
968+
)?;
969+
970+
check_nullability(
971+
when(
972+
or(
973+
is_not_null(col("foo")),
974+
binary_expr(col("foo"), Operator::Eq, lit(5)),
975+
),
976+
col("foo"),
977+
)
978+
.otherwise(lit(0))?,
979+
true,
980+
get_schema,
981+
)?;
982+
983+
check_nullability(
984+
when(
985+
or(
986+
binary_expr(col("foo"), Operator::Eq, lit(5)),
987+
is_not_null(col("foo")),
988+
),
989+
col("foo"),
990+
)
991+
.otherwise(lit(0))?,
992+
true,
993+
get_schema,
994+
)?;
995+
996+
check_nullability(
997+
when(
998+
or(
999+
and(
1000+
binary_expr(col("foo"), Operator::Eq, lit(5)),
1001+
is_not_null(col("foo")),
1002+
),
1003+
and(
1004+
binary_expr(col("foo"), Operator::Eq, col("bar")),
1005+
is_not_null(col("foo")),
1006+
),
1007+
),
1008+
col("foo"),
1009+
)
1010+
.otherwise(lit(0))?,
1011+
false,
1012+
get_schema,
1013+
)?;
1014+
1015+
Ok(())
1016+
}
1017+
8331018
#[test]
8341019
fn test_inlist_nullability() {
8351020
let get_schema = |nullable| {

0 commit comments

Comments
 (0)