@@ -32,6 +32,7 @@ use datafusion_common::{
3232 not_impl_err, plan_datafusion_err, plan_err, Column , DataFusionError , ExprSchema ,
3333 Result , Spans , TableReference ,
3434} ;
35+ use datafusion_expr_common:: operator:: Operator ;
3536use datafusion_expr_common:: type_coercion:: binary:: BinaryTypeCoercer ;
3637use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
3738use std:: sync:: Arc ;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284 let then_nullable = case
284285 . when_then_expr
285286 . iter ( )
287+ . filter ( |( w, t) | {
288+ // Disregard branches where we can determine statically that the predicate
289+ // is always false when the then expression would evaluate to null
290+ const_result_when_value_is_null ( w, t) . unwrap_or ( true )
291+ } )
286292 . map ( |( _, t) | t. nullable ( input_schema) )
287293 . collect :: < Result < Vec < _ > > > ( ) ?;
288294 if then_nullable. contains ( & true ) {
@@ -647,6 +653,38 @@ impl ExprSchemable for Expr {
647653 }
648654}
649655
656+ /// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+ /// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+ fn const_result_when_value_is_null ( predicate : & Expr , value : & Expr ) -> Option < bool > {
659+ match predicate {
660+ Expr :: IsNotNull ( e) => if e. as_ref ( ) . eq ( value) { Some ( false ) } else { None } ,
661+ Expr :: IsNull ( e) => if e. as_ref ( ) . eq ( value) { Some ( true ) } else { None } ,
662+ Expr :: Not ( e) => const_result_when_value_is_null ( e, value) . map ( |b| !b) ,
663+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => match op {
664+ Operator :: And => {
665+ let l = const_result_when_value_is_null ( left, value) ;
666+ let r = const_result_when_value_is_null ( right, value) ;
667+ match ( l, r) {
668+ ( Some ( l) , Some ( r) ) => Some ( l && r) ,
669+ ( Some ( l) , None ) => Some ( l) ,
670+ ( None , Some ( r) ) => Some ( r) ,
671+ _ => None ,
672+ }
673+ }
674+ Operator :: Or => {
675+ let l = const_result_when_value_is_null ( left, value) ;
676+ let r = const_result_when_value_is_null ( right, value) ;
677+ match ( l, r) {
678+ ( Some ( l) , Some ( r) ) => Some ( l || r) ,
679+ _ => None ,
680+ }
681+ }
682+ _ => None ,
683+ } ,
684+ _ => None ,
685+ }
686+ }
687+
650688impl Expr {
651689 /// Common method for window functions that applies type coercion
652690 /// to all arguments of the window function to check if it matches
@@ -777,7 +815,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777815#[ cfg( test) ]
778816mod tests {
779817 use super :: * ;
780- use crate :: { col, lit, out_ref_col_with_metadata} ;
818+ use crate :: { and , binary_expr , col, is_not_null , is_null , lit, not , or , out_ref_col_with_metadata, when } ;
781819
782820 use datafusion_common:: { internal_err, DFSchema , HashMap , ScalarValue } ;
783821
@@ -830,6 +868,153 @@ mod tests {
830868 assert ! ( expr. nullable( & get_schema( false ) ) . unwrap( ) ) ;
831869 }
832870
871+ fn check_nullability (
872+ expr : Expr ,
873+ nullable : bool ,
874+ get_schema : fn ( bool ) -> MockExprSchema ,
875+ ) -> Result < ( ) > {
876+ assert_eq ! (
877+ expr. nullable( & get_schema( true ) ) ?,
878+ nullable,
879+ "Nullability of '{}' should be {} when column is nullable" ,
880+ expr,
881+ nullable
882+ ) ;
883+ assert ! (
884+ !expr. nullable( & get_schema( false ) ) ?,
885+ "Nullability of '{}' should be false when column is not nullable" ,
886+ expr
887+ ) ;
888+ Ok ( ( ) )
889+ }
890+
891+ #[ test]
892+ fn test_case_expression_nullability ( ) -> Result < ( ) > {
893+ let get_schema = |nullable| {
894+ MockExprSchema :: new ( )
895+ . with_data_type ( DataType :: Int32 )
896+ . with_nullable ( nullable)
897+ } ;
898+
899+ check_nullability (
900+ when ( is_not_null ( col ( "foo" ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
901+ false ,
902+ get_schema,
903+ ) ?;
904+
905+ check_nullability (
906+ when ( not ( is_null ( col ( "foo" ) ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
907+ false ,
908+ get_schema,
909+ ) ?;
910+
911+ check_nullability (
912+ when ( binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) , col ( "foo" ) )
913+ . otherwise ( lit ( 0 ) ) ?,
914+ true ,
915+ get_schema,
916+ ) ?;
917+
918+ check_nullability (
919+ when (
920+ and (
921+ is_not_null ( col ( "foo" ) ) ,
922+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
923+ ) ,
924+ col ( "foo" ) ,
925+ )
926+ . otherwise ( lit ( 0 ) ) ?,
927+ false ,
928+ get_schema,
929+ ) ?;
930+
931+ check_nullability (
932+ when (
933+ and (
934+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
935+ is_not_null ( col ( "foo" ) ) ,
936+ ) ,
937+ col ( "foo" ) ,
938+ )
939+ . otherwise ( lit ( 0 ) ) ?,
940+ false ,
941+ get_schema,
942+ ) ?;
943+
944+ check_nullability (
945+ when (
946+ or (
947+ is_not_null ( col ( "foo" ) ) ,
948+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
949+ ) ,
950+ col ( "foo" ) ,
951+ )
952+ . otherwise ( lit ( 0 ) ) ?,
953+ true ,
954+ get_schema,
955+ ) ?;
956+
957+ check_nullability (
958+ when (
959+ or (
960+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
961+ is_not_null ( col ( "foo" ) ) ,
962+ ) ,
963+ col ( "foo" ) ,
964+ )
965+ . otherwise ( lit ( 0 ) ) ?,
966+ true ,
967+ get_schema,
968+ ) ?;
969+
970+ check_nullability (
971+ when (
972+ or (
973+ is_not_null ( col ( "foo" ) ) ,
974+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
975+ ) ,
976+ col ( "foo" ) ,
977+ )
978+ . otherwise ( lit ( 0 ) ) ?,
979+ true ,
980+ get_schema,
981+ ) ?;
982+
983+ check_nullability (
984+ when (
985+ or (
986+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
987+ is_not_null ( col ( "foo" ) ) ,
988+ ) ,
989+ col ( "foo" ) ,
990+ )
991+ . otherwise ( lit ( 0 ) ) ?,
992+ true ,
993+ get_schema,
994+ ) ?;
995+
996+ check_nullability (
997+ when (
998+ or (
999+ and (
1000+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
1001+ is_not_null ( col ( "foo" ) ) ,
1002+ ) ,
1003+ and (
1004+ binary_expr ( col ( "foo" ) , Operator :: Eq , col ( "bar" ) ) ,
1005+ is_not_null ( col ( "foo" ) ) ,
1006+ ) ,
1007+ ) ,
1008+ col ( "foo" ) ,
1009+ )
1010+ . otherwise ( lit ( 0 ) ) ?,
1011+ false ,
1012+ get_schema,
1013+ ) ?;
1014+
1015+ Ok ( ( ) )
1016+ }
1017+
8331018 #[ test]
8341019 fn test_inlist_nullability ( ) {
8351020 let get_schema = |nullable| {
0 commit comments