@@ -32,6 +32,7 @@ use datafusion_common::{
3232 not_impl_err, plan_datafusion_err, plan_err, Column , DataFusionError , ExprSchema ,
3333 Result , Spans , TableReference ,
3434} ;
35+ use datafusion_expr_common:: operator:: Operator ;
3536use datafusion_expr_common:: type_coercion:: binary:: BinaryTypeCoercer ;
3637use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
3738use std:: sync:: Arc ;
@@ -283,6 +284,11 @@ impl ExprSchemable for Expr {
283284 let then_nullable = case
284285 . when_then_expr
285286 . iter ( )
287+ . filter ( |( w, t) | {
288+ // Disregard branches where we can determine statically that the predicate
289+ // is always false when the then expression would evaluate to null
290+ const_result_when_value_is_null ( w, t) . unwrap_or ( true )
291+ } )
286292 . map ( |( _, t) | t. nullable ( input_schema) )
287293 . collect :: < Result < Vec < _ > > > ( ) ?;
288294 if then_nullable. contains ( & true ) {
@@ -647,6 +653,50 @@ impl ExprSchemable for Expr {
647653 }
648654}
649655
656+ /// Determines if the given `predicate` can be const evaluated if `value` were to evaluate to `NULL`.
657+ /// Returns a `Some` value containing the const result if so; otherwise returns `None`.
658+ fn const_result_when_value_is_null ( predicate : & Expr , value : & Expr ) -> Option < bool > {
659+ match predicate {
660+ Expr :: IsNotNull ( e) => {
661+ if e. as_ref ( ) . eq ( value) {
662+ Some ( false )
663+ } else {
664+ None
665+ }
666+ }
667+ Expr :: IsNull ( e) => {
668+ if e. as_ref ( ) . eq ( value) {
669+ Some ( true )
670+ } else {
671+ None
672+ }
673+ }
674+ Expr :: Not ( e) => const_result_when_value_is_null ( e, value) . map ( |b| !b) ,
675+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => match op {
676+ Operator :: And => {
677+ let l = const_result_when_value_is_null ( left, value) ;
678+ let r = const_result_when_value_is_null ( right, value) ;
679+ match ( l, r) {
680+ ( Some ( l) , Some ( r) ) => Some ( l && r) ,
681+ ( Some ( l) , None ) => Some ( l) ,
682+ ( None , Some ( r) ) => Some ( r) ,
683+ _ => None ,
684+ }
685+ }
686+ Operator :: Or => {
687+ let l = const_result_when_value_is_null ( left, value) ;
688+ let r = const_result_when_value_is_null ( right, value) ;
689+ match ( l, r) {
690+ ( Some ( l) , Some ( r) ) => Some ( l || r) ,
691+ _ => None ,
692+ }
693+ }
694+ _ => None ,
695+ } ,
696+ _ => None ,
697+ }
698+ }
699+
650700impl Expr {
651701 /// Common method for window functions that applies type coercion
652702 /// to all arguments of the window function to check if it matches
@@ -777,7 +827,10 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
777827#[ cfg( test) ]
778828mod tests {
779829 use super :: * ;
780- use crate :: { col, lit, out_ref_col_with_metadata} ;
830+ use crate :: {
831+ and, binary_expr, col, is_not_null, is_null, lit, not, or,
832+ out_ref_col_with_metadata, when,
833+ } ;
781834
782835 use datafusion_common:: { internal_err, DFSchema , HashMap , ScalarValue } ;
783836
@@ -830,6 +883,153 @@ mod tests {
830883 assert ! ( expr. nullable( & get_schema( false ) ) . unwrap( ) ) ;
831884 }
832885
886+ fn check_nullability (
887+ expr : Expr ,
888+ nullable : bool ,
889+ get_schema : fn ( bool ) -> MockExprSchema ,
890+ ) -> Result < ( ) > {
891+ assert_eq ! (
892+ expr. nullable( & get_schema( true ) ) ?,
893+ nullable,
894+ "Nullability of '{}' should be {} when column is nullable" ,
895+ expr,
896+ nullable
897+ ) ;
898+ assert ! (
899+ !expr. nullable( & get_schema( false ) ) ?,
900+ "Nullability of '{}' should be false when column is not nullable" ,
901+ expr
902+ ) ;
903+ Ok ( ( ) )
904+ }
905+
906+ #[ test]
907+ fn test_case_expression_nullability ( ) -> Result < ( ) > {
908+ let get_schema = |nullable| {
909+ MockExprSchema :: new ( )
910+ . with_data_type ( DataType :: Int32 )
911+ . with_nullable ( nullable)
912+ } ;
913+
914+ check_nullability (
915+ when ( is_not_null ( col ( "foo" ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
916+ false ,
917+ get_schema,
918+ ) ?;
919+
920+ check_nullability (
921+ when ( not ( is_null ( col ( "foo" ) ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
922+ false ,
923+ get_schema,
924+ ) ?;
925+
926+ check_nullability (
927+ when ( binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) , col ( "foo" ) )
928+ . otherwise ( lit ( 0 ) ) ?,
929+ true ,
930+ get_schema,
931+ ) ?;
932+
933+ check_nullability (
934+ when (
935+ and (
936+ is_not_null ( col ( "foo" ) ) ,
937+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
938+ ) ,
939+ col ( "foo" ) ,
940+ )
941+ . otherwise ( lit ( 0 ) ) ?,
942+ false ,
943+ get_schema,
944+ ) ?;
945+
946+ check_nullability (
947+ when (
948+ and (
949+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
950+ is_not_null ( col ( "foo" ) ) ,
951+ ) ,
952+ col ( "foo" ) ,
953+ )
954+ . otherwise ( lit ( 0 ) ) ?,
955+ false ,
956+ get_schema,
957+ ) ?;
958+
959+ check_nullability (
960+ when (
961+ or (
962+ is_not_null ( col ( "foo" ) ) ,
963+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
964+ ) ,
965+ col ( "foo" ) ,
966+ )
967+ . otherwise ( lit ( 0 ) ) ?,
968+ true ,
969+ get_schema,
970+ ) ?;
971+
972+ check_nullability (
973+ when (
974+ or (
975+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
976+ is_not_null ( col ( "foo" ) ) ,
977+ ) ,
978+ col ( "foo" ) ,
979+ )
980+ . otherwise ( lit ( 0 ) ) ?,
981+ true ,
982+ get_schema,
983+ ) ?;
984+
985+ check_nullability (
986+ when (
987+ or (
988+ is_not_null ( col ( "foo" ) ) ,
989+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
990+ ) ,
991+ col ( "foo" ) ,
992+ )
993+ . otherwise ( lit ( 0 ) ) ?,
994+ true ,
995+ get_schema,
996+ ) ?;
997+
998+ check_nullability (
999+ when (
1000+ or (
1001+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
1002+ is_not_null ( col ( "foo" ) ) ,
1003+ ) ,
1004+ col ( "foo" ) ,
1005+ )
1006+ . otherwise ( lit ( 0 ) ) ?,
1007+ true ,
1008+ get_schema,
1009+ ) ?;
1010+
1011+ check_nullability (
1012+ when (
1013+ or (
1014+ and (
1015+ binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
1016+ is_not_null ( col ( "foo" ) ) ,
1017+ ) ,
1018+ and (
1019+ binary_expr ( col ( "foo" ) , Operator :: Eq , col ( "bar" ) ) ,
1020+ is_not_null ( col ( "foo" ) ) ,
1021+ ) ,
1022+ ) ,
1023+ col ( "foo" ) ,
1024+ )
1025+ . otherwise ( lit ( 0 ) ) ?,
1026+ false ,
1027+ get_schema,
1028+ ) ?;
1029+
1030+ Ok ( ( ) )
1031+ }
1032+
8331033 #[ test]
8341034 fn test_inlist_nullability ( ) {
8351035 let get_schema = |nullable| {
0 commit comments