@@ -344,7 +344,16 @@ impl CaseExpr {
344344 fn case_column_or_null ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
345345 let when_expr = & self . when_then_expr [ 0 ] . 0 ;
346346 let then_expr = & self . when_then_expr [ 0 ] . 1 ;
347- if let ColumnarValue :: Array ( bit_mask) = when_expr. evaluate ( batch) ? {
347+
348+ let when_expr_value = when_expr. evaluate ( batch) ?;
349+ let when_expr_value = match when_expr_value {
350+ ColumnarValue :: Scalar ( _) => {
351+ ColumnarValue :: Array ( when_expr_value. into_array ( batch. num_rows ( ) ) ?)
352+ }
353+ other => other,
354+ } ;
355+
356+ if let ColumnarValue :: Array ( bit_mask) = when_expr_value {
348357 let bit_mask = bit_mask
349358 . as_any ( )
350359 . downcast_ref :: < BooleanArray > ( )
@@ -896,6 +905,53 @@ mod tests {
896905 Ok ( ( ) )
897906 }
898907
908+ #[ test]
909+ fn case_with_scalar_predicate ( ) -> Result < ( ) > {
910+ let batch = case_test_batch_nulls ( ) ?;
911+ let schema = batch. schema ( ) ;
912+
913+ // SELECT CASE WHEN TRUE THEN load4 END
914+ let when = lit ( true ) ;
915+ let then = col ( "load4" , & schema) ?;
916+ let expr = generate_case_when_with_type_coercion (
917+ None ,
918+ vec ! [ ( when, then) ] ,
919+ None ,
920+ schema. as_ref ( ) ,
921+ ) ?;
922+
923+ // many rows
924+ let result = expr
925+ . evaluate ( & batch) ?
926+ . into_array ( batch. num_rows ( ) )
927+ . expect ( "Failed to convert to array" ) ;
928+ let result =
929+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
930+ let expected = & Float64Array :: from ( vec ! [
931+ Some ( 1.77 ) ,
932+ None ,
933+ None ,
934+ Some ( 1.78 ) ,
935+ None ,
936+ Some ( 1.77 ) ,
937+ ] ) ;
938+ assert_eq ! ( expected, result) ;
939+
940+ // one row
941+ let expected = Float64Array :: from ( vec ! [ Some ( 1.1 ) ] ) ;
942+ let batch =
943+ RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [ Arc :: new( expected. clone( ) ) ] ) ?;
944+ let result = expr
945+ . evaluate ( & batch) ?
946+ . into_array ( batch. num_rows ( ) )
947+ . expect ( "Failed to convert to array" ) ;
948+ let result =
949+ as_float64_array ( & result) . expect ( "failed to downcast to Float64Array" ) ;
950+ assert_eq ! ( & expected, result) ;
951+
952+ Ok ( ( ) )
953+ }
954+
899955 #[ test]
900956 fn case_expr_matches_and_nulls ( ) -> Result < ( ) > {
901957 let batch = case_test_batch_nulls ( ) ?;
0 commit comments