@@ -28,6 +28,7 @@ use crate::udf::ReturnFieldArgs;
2828use crate :: { utils, LogicalPlan , Projection , Subquery , WindowFunctionDefinition } ;
2929use arrow:: compute:: can_cast_types;
3030use arrow:: datatypes:: { DataType , Field , FieldRef } ;
31+ use datafusion_common:: tree_node:: TreeNode ;
3132use datafusion_common:: {
3233 not_impl_err, plan_datafusion_err, plan_err, Column , DataFusionError , ExprSchema ,
3334 Result , Spans , TableReference ,
@@ -827,10 +828,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subq
827828#[ cfg( test) ]
828829mod tests {
829830 use super :: * ;
830- use crate :: {
831- and, binary_expr, col, is_not_null, is_null, lit, not, or,
832- out_ref_col_with_metadata, when,
833- } ;
831+ use crate :: { and, col, lit, not, or, out_ref_col_with_metadata, when} ;
834832
835833 use datafusion_common:: { internal_err, DFSchema , HashMap , ScalarValue } ;
836834
@@ -883,146 +881,88 @@ mod tests {
883881 assert ! ( expr. nullable( & get_schema( false ) ) . unwrap( ) ) ;
884882 }
885883
886- fn check_nullability (
887- expr : Expr ,
888- nullable : bool ,
889- get_schema : fn ( bool ) -> MockExprSchema ,
890- ) -> Result < ( ) > {
884+ fn assert_nullability ( expr : & Expr , schema : & dyn ExprSchema , nullable : bool ) {
891885 assert_eq ! (
892- expr. nullable( & get_schema ( true ) ) ? ,
886+ expr. nullable( schema ) . unwrap ( ) ,
893887 nullable,
894- "Nullability of '{expr}' should be {nullable} when column is nullable"
895- ) ;
896- assert ! (
897- !expr. nullable( & get_schema( false ) ) ?,
898- "Nullability of '{expr}' should be false when column is not nullable"
888+ "Nullability of '{expr}' should be {nullable}"
899889 ) ;
900- Ok ( ( ) )
890+ }
891+
892+ fn assert_not_nullable ( expr : & Expr , schema : & dyn ExprSchema ) {
893+ assert_nullability ( expr, schema, false ) ;
894+ }
895+
896+ fn assert_nullable ( expr : & Expr , schema : & dyn ExprSchema ) {
897+ assert_nullability ( expr, schema, true ) ;
901898 }
902899
903900 #[ test]
904901 fn test_case_expression_nullability ( ) -> Result < ( ) > {
905- let get_schema = |nullable| {
906- MockExprSchema :: new ( )
907- . with_data_type ( DataType :: Int32 )
908- . with_nullable ( nullable)
909- } ;
902+ let nullable_schema = MockExprSchema :: new ( )
903+ . with_data_type ( DataType :: Int32 )
904+ . with_nullable ( true ) ;
910905
911- check_nullability (
912- when ( is_not_null ( col ( "foo" ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
913- false ,
914- get_schema,
915- ) ?;
916-
917- check_nullability (
918- when ( not ( is_null ( col ( "foo" ) ) ) , col ( "foo" ) ) . otherwise ( lit ( 0 ) ) ?,
919- false ,
920- get_schema,
921- ) ?;
922-
923- check_nullability (
924- when ( binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) , col ( "foo" ) )
925- . otherwise ( lit ( 0 ) ) ?,
926- true ,
927- get_schema,
928- ) ?;
929-
930- check_nullability (
931- when (
932- and (
933- is_not_null ( col ( "foo" ) ) ,
934- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
935- ) ,
936- col ( "foo" ) ,
937- )
938- . otherwise ( lit ( 0 ) ) ?,
939- false ,
940- get_schema,
941- ) ?;
942-
943- check_nullability (
944- when (
945- and (
946- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
947- is_not_null ( col ( "foo" ) ) ,
948- ) ,
949- col ( "foo" ) ,
950- )
951- . otherwise ( lit ( 0 ) ) ?,
952- false ,
953- get_schema,
954- ) ?;
955-
956- check_nullability (
957- when (
958- or (
959- is_not_null ( col ( "foo" ) ) ,
960- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
961- ) ,
962- col ( "foo" ) ,
963- )
964- . otherwise ( lit ( 0 ) ) ?,
965- true ,
966- get_schema,
967- ) ?;
968-
969- check_nullability (
970- when (
971- or (
972- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
973- is_not_null ( col ( "foo" ) ) ,
974- ) ,
975- col ( "foo" ) ,
976- )
977- . otherwise ( lit ( 0 ) ) ?,
978- true ,
979- get_schema,
980- ) ?;
981-
982- check_nullability (
983- when (
984- or (
985- is_not_null ( col ( "foo" ) ) ,
986- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
987- ) ,
988- col ( "foo" ) ,
989- )
990- . otherwise ( lit ( 0 ) ) ?,
991- true ,
992- get_schema,
993- ) ?;
994-
995- check_nullability (
996- when (
997- or (
998- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
999- is_not_null ( col ( "foo" ) ) ,
1000- ) ,
1001- col ( "foo" ) ,
1002- )
1003- . otherwise ( lit ( 0 ) ) ?,
1004- true ,
1005- get_schema,
1006- ) ?;
1007-
1008- check_nullability (
1009- when (
1010- or (
1011- and (
1012- binary_expr ( col ( "foo" ) , Operator :: Eq , lit ( 5 ) ) ,
1013- is_not_null ( col ( "foo" ) ) ,
1014- ) ,
1015- and (
1016- binary_expr ( col ( "foo" ) , Operator :: Eq , col ( "bar" ) ) ,
1017- is_not_null ( col ( "foo" ) ) ,
1018- ) ,
1019- ) ,
1020- col ( "foo" ) ,
1021- )
1022- . otherwise ( lit ( 0 ) ) ?,
1023- false ,
1024- get_schema,
1025- ) ?;
906+ let not_nullable_schema = MockExprSchema :: new ( )
907+ . with_data_type ( DataType :: Int32 )
908+ . with_nullable ( false ) ;
909+
910+ // CASE WHEN x IS NOT NULL THEN x ELSE 0
911+ let e1 = when ( col ( "x" ) . is_not_null ( ) , col ( "x" ) ) . otherwise ( lit ( 0 ) ) ?;
912+ assert_not_nullable ( & e1, & nullable_schema) ;
913+ assert_not_nullable ( & e1, & not_nullable_schema) ;
914+
915+ // CASE WHEN NOT x IS NULL THEN x ELSE 0
916+ let e2 = when ( not ( col ( "x" ) . is_null ( ) ) , col ( "x" ) ) . otherwise ( lit ( 0 ) ) ?;
917+ assert_not_nullable ( & e2, & nullable_schema) ;
918+ assert_not_nullable ( & e2, & not_nullable_schema) ;
919+
920+ // CASE WHEN X = 5 THEN x ELSE 0
921+ let e3 = when ( col ( "x" ) . eq ( lit ( 5 ) ) , col ( "x" ) ) . otherwise ( lit ( 0 ) ) ?;
922+ assert_nullable ( & e3, & nullable_schema) ;
923+ assert_not_nullable ( & e3, & not_nullable_schema) ;
924+
925+ // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0
926+ let e4 = when ( and ( col ( "x" ) . is_not_null ( ) , col ( "x" ) . eq ( lit ( 5 ) ) ) , col ( "x" ) )
927+ . otherwise ( lit ( 0 ) ) ?;
928+ assert_not_nullable ( & e4, & nullable_schema) ;
929+ assert_not_nullable ( & e4, & not_nullable_schema) ;
930+
931+ // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0
932+ let e5 = when ( and ( col ( "x" ) . eq ( lit ( 5 ) ) , col ( "x" ) . is_not_null ( ) ) , col ( "x" ) )
933+ . otherwise ( lit ( 0 ) ) ?;
934+ assert_not_nullable ( & e5, & nullable_schema) ;
935+ assert_not_nullable ( & e5, & not_nullable_schema) ;
936+
937+ // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0
938+ let e6 = when ( or ( col ( "x" ) . is_not_null ( ) , col ( "x" ) . eq ( lit ( 5 ) ) ) , col ( "x" ) )
939+ . otherwise ( lit ( 0 ) ) ?;
940+ assert_nullable ( & e6, & nullable_schema) ;
941+ assert_not_nullable ( & e6, & not_nullable_schema) ;
942+
943+ // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0
944+ let e7 = when ( or ( col ( "x" ) . eq ( lit ( 5 ) ) , col ( "x" ) . is_not_null ( ) ) , col ( "x" ) )
945+ . otherwise ( lit ( 0 ) ) ?;
946+ assert_nullable ( & e7, & nullable_schema) ;
947+ assert_not_nullable ( & e7, & not_nullable_schema) ;
948+
949+ // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0
950+ let e8 = when (
951+ or (
952+ and ( col ( "x" ) . eq ( lit ( 5 ) ) , col ( "x" ) . is_not_null ( ) ) ,
953+ and ( col ( "x" ) . eq ( col ( "bar" ) ) , col ( "x" ) . is_not_null ( ) ) ,
954+ ) ,
955+ col ( "x" ) ,
956+ )
957+ . otherwise ( lit ( 0 ) ) ?;
958+ assert_not_nullable ( & e8, & nullable_schema) ;
959+ assert_not_nullable ( & e8, & not_nullable_schema) ;
960+
961+ // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0
962+ let e9 = when ( or ( col ( "x" ) . eq ( lit ( 5 ) ) , col ( "x" ) . is_null ( ) ) , col ( "x" ) )
963+ . otherwise ( lit ( 0 ) ) ?;
964+ assert_nullable ( & e9, & nullable_schema) ;
965+ assert_not_nullable ( & e9, & not_nullable_schema) ;
1026966
1027967 Ok ( ( ) )
1028968 }
0 commit comments