@@ -32,6 +32,7 @@ use datafusion_common::{
3232 not_impl_err, plan_datafusion_err, plan_err, Column , DataFusionError , ExprSchema ,
3333 Result , Spans , TableReference ,
3434} ;
35+ use datafusion_expr_common:: operator:: Operator ;
3536use datafusion_expr_common:: type_coercion:: binary:: BinaryTypeCoercer ;
3637use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
3738use std:: sync:: Arc ;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284 let then_nullable = case
284285 . when_then_expr
285286 . iter ( )
287+ . filter ( |( w, t) | {
288+ // Disregard branches where we can determine statically that the predicate
289+ // is always false when the then expression would evaluate to null
290+ const_result_when_value_is_null ( w, t) . unwrap_or ( true )
291+ } )
286292 . map ( |( _, t) | t. nullable ( input_schema) )
287293 . collect :: < Result < Vec < _ > > > ( ) ?;
288294 if then_nullable. contains ( & true ) {
@@ -647,6 +653,50 @@ impl ExprSchemable for Expr {
647653 }
648654}
649655
656+ /// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+ /// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+ fn const_result_when_value_is_null ( predicate : & Expr , value : & Expr ) -> Option < bool > {
659+ match predicate {
660+ Expr :: IsNotNull ( e) => {
661+ if e. as_ref ( ) . eq ( value) {
662+ Some ( false )
663+ } else {
664+ None
665+ }
666+ }
667+ Expr :: IsNull ( e) => {
668+ if e. as_ref ( ) . eq ( value) {
669+ Some ( true )
670+ } else {
671+ None
672+ }
673+ }
674+ Expr :: Not ( e) => const_result_when_value_is_null ( e, value) . map ( |b| !b) ,
675+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => match op {
676+ Operator :: And => {
677+ let l = const_result_when_value_is_null ( left, value) ;
678+ let r = const_result_when_value_is_null ( right, value) ;
679+ match ( l, r) {
680+ ( Some ( l) , Some ( r) ) => Some ( l && r) ,
681+ ( Some ( l) , None ) => Some ( l) ,
682+ ( None , Some ( r) ) => Some ( r) ,
683+ _ => None ,
684+ }
685+ }
686+ Operator :: Or => {
687+ let l = const_result_when_value_is_null ( left, value) ;
688+ let r = const_result_when_value_is_null ( right, value) ;
689+ match ( l, r) {
690+ ( Some ( l) , Some ( r) ) => Some ( l || r) ,
691+ _ => None ,
692+ }
693+ }
694+ _ => None ,
695+ } ,
696+ _ => None ,
697+ }
698+ }
699+
650700impl Expr {
651701 /// Common method for window functions that applies type coercion
652702 /// to all arguments of the window function to check if it matches
@@ -777,7 +827,10 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777827#[ cfg( test) ]
778828mod tests {
779829 use super :: * ;
780- use crate :: { col, lit, out_ref_col_with_metadata} ;
830+ use crate :: {
831+ and, binary_expr, col, is_not_null, is_null, lit, not, or,
832+ out_ref_col_with_metadata, when,
833+ } ;
781834
782835 use datafusion_common:: { internal_err, DFSchema , HashMap , ScalarValue } ;
783836
@@ -830,6 +883,150 @@ mod tests {
830883 assert ! ( expr. nullable( & get_schema( false ) ) . unwrap( ) ) ;
831884 }
832885
886+ fn check_nullability (
887+ expr : Expr ,
888+ nullable : bool ,
889+ get_schema : fn ( bool ) -> MockExprSchema ,
890+ ) -> Result < ( ) > {
891+ assert_eq ! (
892+ expr. nullable( & get_schema( true ) ) ?,
893+ nullable,
894+ "Nullability of '{expr}' should be {nullable} when column is nullable"
895+ ) ;
896+ assert ! (
897+ !expr. nullable( & get_schema( false ) ) ?,
898+ "Nullability of '{expr}' should be false when column is not nullable"
899+ ) ;
900+ Ok ( ( ) )
901+ }
902+
903+ #[ test]
904+ fn test_case_expression_nullability ( ) -> Result < ( ) > {
905+ let get_schema = |nullable| {
906+ MockExprSchema :: new ( )
907+ . with_data_type ( DataType :: Int32 )
908+ . with_nullable ( nullable)
909+ } ;
910+
911+ check_nullability (
912+ when ( is_not_null ( col ( "foo" ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
913+ false ,
914+ get_schema,
915+ ) ?;
916+
917+ check_nullability (
918+ when ( not ( is_null ( col ( "foo" ) ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
919+ false ,
920+ get_schema,
921+ ) ?;
922+
923+ check_nullability (
924+ when ( binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) , col ( "foo" ) )
925+ . otherwise ( lit ( 0 ) ) ?,
926+ true ,
927+ get_schema,
928+ ) ?;
929+
930+ check_nullability (
931+ when (
932+ and (
933+ is_not_null ( col ( "foo" ) ) ,
934+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
935+ ) ,
936+ col ( "foo" ) ,
937+ )
938+ . otherwise ( lit ( 0 ) ) ?,
939+ false ,
940+ get_schema,
941+ ) ?;
942+
943+ check_nullability (
944+ when (
945+ and (
946+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
947+ is_not_null ( col ( "foo" ) ) ,
948+ ) ,
949+ col ( "foo" ) ,
950+ )
951+ . otherwise ( lit ( 0 ) ) ?,
952+ false ,
953+ get_schema,
954+ ) ?;
955+
956+ check_nullability (
957+ when (
958+ or (
959+ is_not_null ( col ( "foo" ) ) ,
960+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
961+ ) ,
962+ col ( "foo" ) ,
963+ )
964+ . otherwise ( lit ( 0 ) ) ?,
965+ true ,
966+ get_schema,
967+ ) ?;
968+
969+ check_nullability (
970+ when (
971+ or (
972+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
973+ is_not_null ( col ( "foo" ) ) ,
974+ ) ,
975+ col ( "foo" ) ,
976+ )
977+ . otherwise ( lit ( 0 ) ) ?,
978+ true ,
979+ get_schema,
980+ ) ?;
981+
982+ check_nullability (
983+ when (
984+ or (
985+ is_not_null ( col ( "foo" ) ) ,
986+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
987+ ) ,
988+ col ( "foo" ) ,
989+ )
990+ . otherwise ( lit ( 0 ) ) ?,
991+ true ,
992+ get_schema,
993+ ) ?;
994+
995+ check_nullability (
996+ when (
997+ or (
998+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
999+ is_not_null ( col ( "foo" ) ) ,
1000+ ) ,
1001+ col ( "foo" ) ,
1002+ )
1003+ . otherwise ( lit ( 0 ) ) ?,
1004+ true ,
1005+ get_schema,
1006+ ) ?;
1007+
1008+ check_nullability (
1009+ when (
1010+ or (
1011+ and (
1012+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
1013+ is_not_null ( col ( "foo" ) ) ,
1014+ ) ,
1015+ and (
1016+ binary_expr ( col ( "foo" ) , Operator :: Eq , col ( "bar" ) ) ,
1017+ is_not_null ( col ( "foo" ) ) ,
1018+ ) ,
1019+ ) ,
1020+ col ( "foo" ) ,
1021+ )
1022+ . otherwise ( lit ( 0 ) ) ?,
1023+ false ,
1024+ get_schema,
1025+ ) ?;
1026+
1027+ Ok ( ( ) )
1028+ }
1029+
8331030 #[ test]
8341031 fn test_inlist_nullability ( ) {
8351032 let get_schema = |nullable| {
0 commit comments