Skip to content

Commit d2e613a

Browse files
committed
#17801 Improve nullability reporting of case expressions
1 parent 5bbdb7e commit d2e613a

File tree

3 files changed

+80
-6
lines changed

3 files changed

+80
-6
lines changed

datafusion/core/tests/tpcds_planning.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
10541054
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
10561055
if create_physical {
10571056
let _ = state.create_physical_plan(&plan).await?;
1057+
} else {
1058+
let _ = state.optimize(&plan)?;
10581059
}
10591060
}
10601061

datafusion/expr/src/expr_schema.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use datafusion_common::{
3232
not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
3333
Result, Spans, TableReference,
3434
};
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
3637
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3738
use std::sync::Arc;
@@ -283,6 +284,9 @@ impl ExprSchemable for Expr {
283284
let then_nullable = case
284285
.when_then_expr
285286
.iter()
287+
.filter(|(w, t)| {
288+
!always_false_when_value_is_null(w, t).unwrap_or(false)
289+
})
286290
.map(|(_, t)| t.nullable(input_schema))
287291
.collect::<Result<Vec<_>>>()?;
288292
if then_nullable.contains(&true) {
@@ -647,6 +651,35 @@ impl ExprSchemable for Expr {
647651
}
648652
}
649653

654+
fn always_false_when_value_is_null(predicate: &Expr, value: &Expr) -> Option<bool> {
655+
match predicate {
656+
Expr::IsNotNull(e) => Some(e.as_ref().eq(value)),
657+
Expr::Not(e) => always_false_when_value_is_null(e, value).map(|b| !b),
658+
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
659+
Operator::And => {
660+
let l = always_false_when_value_is_null(left, value);
661+
let r = always_false_when_value_is_null(right, value);
662+
if l.is_some() && r.is_some() {
663+
Some(l.unwrap() || r.unwrap())
664+
} else {
665+
None
666+
}
667+
}
668+
Operator::Or => {
669+
let l = always_false_when_value_is_null(left, value);
670+
let r = always_false_when_value_is_null(right, value);
671+
if l.is_some() && r.is_some() {
672+
Some(l.unwrap() && r.unwrap())
673+
} else {
674+
None
675+
}
676+
}
677+
_ => None,
678+
},
679+
_ => None,
680+
}
681+
}
682+
650683
impl Expr {
651684
/// Common method for window functions that applies type coercion
652685
/// to all arguments of the window function to check if it matches

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::expressions::try_cast;
18+
use crate::expressions::{try_cast, BinaryExpr, IsNotNullExpr, NotExpr};
1919
use crate::PhysicalExpr;
20-
use std::borrow::Cow;
21-
use std::hash::Hash;
22-
use std::{any::Any, sync::Arc};
23-
2420
use arrow::array::*;
2521
use arrow::compute::kernels::zip::zip;
2622
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
@@ -30,8 +26,13 @@ use datafusion_common::{
3026
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
3127
};
3228
use datafusion_expr::ColumnarValue;
29+
use std::borrow::Cow;
30+
use std::hash::Hash;
31+
use std::{any::Any, sync::Arc};
3332

3433
use super::{Column, Literal};
34+
use datafusion_expr_common::dyn_eq::DynEq;
35+
use datafusion_expr_common::operator::Operator;
3536
use datafusion_physical_expr_common::datum::compare_with_eq;
3637
use itertools::Itertools;
3738

@@ -481,6 +482,9 @@ impl PhysicalExpr for CaseExpr {
481482
let then_nullable = self
482483
.when_then_expr
483484
.iter()
485+
.filter(|(w, t)| {
486+
!always_false_when_value_is_null(w.as_ref(), t.as_ref()).unwrap_or(false)
487+
})
484488
.map(|(_, t)| t.nullable(input_schema))
485489
.collect::<Result<Vec<_>>>()?;
486490
if then_nullable.contains(&true) {
@@ -588,6 +592,42 @@ impl PhysicalExpr for CaseExpr {
588592
}
589593
}
590594

595+
fn always_false_when_value_is_null(
596+
predicate: &dyn PhysicalExpr,
597+
value: &dyn PhysicalExpr,
598+
) -> Option<bool> {
599+
let predicate_any = predicate.as_any();
600+
if let Some(not_null) = predicate_any.downcast_ref::<IsNotNullExpr>() {
601+
Some(not_null.arg().as_ref().dyn_eq(value))
602+
} else if let Some(not) = predicate_any.downcast_ref::<NotExpr>() {
603+
always_false_when_value_is_null(not.arg().as_ref(), value).map(|b| !b)
604+
} else if let Some(binary) = predicate_any.downcast_ref::<BinaryExpr>() {
605+
match binary.op() {
606+
Operator::And => {
607+
let l = always_false_when_value_is_null(binary.left().as_ref(), value);
608+
let r = always_false_when_value_is_null(binary.right().as_ref(), value);
609+
if l.is_some() && r.is_some() {
610+
Some(l.unwrap() || r.unwrap())
611+
} else {
612+
None
613+
}
614+
}
615+
Operator::Or => {
616+
let l = always_false_when_value_is_null(binary.left().as_ref(), value);
617+
let r = always_false_when_value_is_null(binary.right().as_ref(), value);
618+
if l.is_some() && r.is_some() {
619+
Some(l.unwrap() && r.unwrap())
620+
} else {
621+
None
622+
}
623+
}
624+
_ => None,
625+
}
626+
} else {
627+
None
628+
}
629+
}
630+
591631
/// Create a CASE expression
592632
pub fn case(
593633
expr: Option<Arc<dyn PhysicalExpr>>,

0 commit comments

Comments
 (0)