@@ -60,6 +60,11 @@ enum EvalMethod {
6060 /// are literal values
6161 /// CASE WHEN condition THEN literal ELSE literal END
6262 ScalarOrScalar ,
63+ /// This is a specialization for a specific use case where we can take a fast path
64+ /// if there is just one when/then pair and both the `then` and `else` are expressions
65+ ///
66+ /// CASE WHEN condition THEN expression ELSE expression END
67+ ExpressionOrExpression ,
6368}
6469
6570/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -149,6 +154,8 @@ impl CaseExpr {
149154 && else_expr. as_ref ( ) . unwrap ( ) . as_any ( ) . is :: < Literal > ( )
150155 {
151156 EvalMethod :: ScalarOrScalar
157+ } else if when_then_expr. len ( ) == 1 && else_expr. is_some ( ) {
158+ EvalMethod :: ExpressionOrExpression
152159 } else {
153160 EvalMethod :: NoExpression
154161 } ;
@@ -394,6 +401,43 @@ impl CaseExpr {
394401
395402 Ok ( ColumnarValue :: Array ( zip ( & when_value, & then_value, & else_) ?) )
396403 }
404+
405+ fn expr_or_expr ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
406+ let return_type = self . data_type ( & batch. schema ( ) ) ?;
407+
408+ // evalute when condition on batch
409+ let when_value = self . when_then_expr [ 0 ] . 0 . evaluate ( batch) ?;
410+ let when_value = when_value. into_array ( batch. num_rows ( ) ) ?;
411+ let when_value = as_boolean_array ( & when_value) . map_err ( |e| {
412+ DataFusionError :: Context (
413+ "WHEN expression did not return a BooleanArray" . to_string ( ) ,
414+ Box :: new ( e) ,
415+ )
416+ } ) ?;
417+
418+ // Treat 'NULL' as false value
419+ let when_value = match when_value. null_count ( ) {
420+ 0 => Cow :: Borrowed ( when_value) ,
421+ _ => Cow :: Owned ( prep_null_mask_filter ( when_value) ) ,
422+ } ;
423+
424+ let then_value = self . when_then_expr [ 0 ]
425+ . 1
426+ . evaluate_selection ( batch, & when_value) ?
427+ . into_array ( batch. num_rows ( ) ) ?;
428+
429+ // evaluate else expression on the values not covered by when_value
430+ let remainder = not ( & when_value) ?;
431+ let e = self . else_expr . as_ref ( ) . unwrap ( ) ;
432+ // keep `else_expr`'s data type and return type consistent
433+ let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) )
434+ . unwrap_or_else ( |_| Arc :: clone ( e) ) ;
435+ let else_ = expr
436+ . evaluate_selection ( batch, & remainder) ?
437+ . into_array ( batch. num_rows ( ) ) ?;
438+
439+ Ok ( ColumnarValue :: Array ( zip ( & remainder, & else_, & then_value) ?) )
440+ }
397441}
398442
399443impl PhysicalExpr for CaseExpr {
@@ -457,6 +501,7 @@ impl PhysicalExpr for CaseExpr {
457501 self . case_column_or_null ( batch)
458502 }
459503 EvalMethod :: ScalarOrScalar => self . scalar_or_scalar ( batch) ,
504+ EvalMethod :: ExpressionOrExpression => self . expr_or_expr ( batch) ,
460505 }
461506 }
462507
@@ -1174,6 +1219,45 @@ mod tests {
11741219 Ok ( ( ) )
11751220 }
11761221
1222+ #[ test]
1223+ fn test_expr_or_expr_specialization ( ) -> Result < ( ) > {
1224+ let batch = case_test_batch1 ( ) ?;
1225+ let schema = batch. schema ( ) ;
1226+ let when = binary (
1227+ col ( "a" , & schema) ?,
1228+ Operator :: LtEq ,
1229+ lit ( 2i32 ) ,
1230+ & batch. schema ( ) ,
1231+ ) ?;
1232+ let then = binary (
1233+ col ( "a" , & schema) ?,
1234+ Operator :: Plus ,
1235+ lit ( 1i32 ) ,
1236+ & batch. schema ( ) ,
1237+ ) ?;
1238+ let else_expr = binary (
1239+ col ( "a" , & schema) ?,
1240+ Operator :: Minus ,
1241+ lit ( 1i32 ) ,
1242+ & batch. schema ( ) ,
1243+ ) ?;
1244+ let expr = CaseExpr :: try_new ( None , vec ! [ ( when, then) ] , Some ( else_expr) ) ?;
1245+ assert ! ( matches!(
1246+ expr. eval_method,
1247+ EvalMethod :: ExpressionOrExpression
1248+ ) ) ;
1249+ let result = expr
1250+ . evaluate ( & batch) ?
1251+ . into_array ( batch. num_rows ( ) )
1252+ . expect ( "Failed to convert to array" ) ;
1253+ let result = as_int32_array ( & result) . expect ( "failed to downcast to Int32Array" ) ;
1254+
1255+ let expected = & Int32Array :: from ( vec ! [ Some ( 2 ) , Some ( 1 ) , None , Some ( 4 ) ] ) ;
1256+
1257+ assert_eq ! ( expected, result) ;
1258+ Ok ( ( ) )
1259+ }
1260+
11771261 fn make_col ( name : & str , index : usize ) -> Arc < dyn PhysicalExpr > {
11781262 Arc :: new ( Column :: new ( name, index) )
11791263 }
0 commit comments