Skip to content

Commit 0f4b8b1

Browse files
authored
Optimize CASE expression for "expr or expr" usage. (#13953)
* Apply optimization for ExprOrExpr. * Implement optimization similar to existing code. * Add sqllogictest.
1 parent 39a69f5 commit 0f4b8b1

File tree

2 files changed

+95
-0
lines changed
  • datafusion

2 files changed

+95
-0
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ enum EvalMethod {
6060
/// are literal values
6161
/// CASE WHEN condition THEN literal ELSE literal END
6262
ScalarOrScalar,
63+
/// This is a specialization for a specific use case where we can take a fast path
64+
/// if there is just one when/then pair and both the `then` and `else` are expressions
65+
///
66+
/// CASE WHEN condition THEN expression ELSE expression END
67+
ExpressionOrExpression,
6368
}
6469

6570
/// The CASE expression is similar to a series of nested if/else and there are two forms that
@@ -149,6 +154,8 @@ impl CaseExpr {
149154
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
150155
{
151156
EvalMethod::ScalarOrScalar
157+
} else if when_then_expr.len() == 1 && else_expr.is_some() {
158+
EvalMethod::ExpressionOrExpression
152159
} else {
153160
EvalMethod::NoExpression
154161
};
@@ -394,6 +401,43 @@ impl CaseExpr {
394401

395402
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
396403
}
404+
405+
fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
406+
let return_type = self.data_type(&batch.schema())?;
407+
408+
// evalute when condition on batch
409+
let when_value = self.when_then_expr[0].0.evaluate(batch)?;
410+
let when_value = when_value.into_array(batch.num_rows())?;
411+
let when_value = as_boolean_array(&when_value).map_err(|e| {
412+
DataFusionError::Context(
413+
"WHEN expression did not return a BooleanArray".to_string(),
414+
Box::new(e),
415+
)
416+
})?;
417+
418+
// Treat 'NULL' as false value
419+
let when_value = match when_value.null_count() {
420+
0 => Cow::Borrowed(when_value),
421+
_ => Cow::Owned(prep_null_mask_filter(when_value)),
422+
};
423+
424+
let then_value = self.when_then_expr[0]
425+
.1
426+
.evaluate_selection(batch, &when_value)?
427+
.into_array(batch.num_rows())?;
428+
429+
// evaluate else expression on the values not covered by when_value
430+
let remainder = not(&when_value)?;
431+
let e = self.else_expr.as_ref().unwrap();
432+
// keep `else_expr`'s data type and return type consistent
433+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
434+
.unwrap_or_else(|_| Arc::clone(e));
435+
let else_ = expr
436+
.evaluate_selection(batch, &remainder)?
437+
.into_array(batch.num_rows())?;
438+
439+
Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?))
440+
}
397441
}
398442

399443
impl PhysicalExpr for CaseExpr {
@@ -457,6 +501,7 @@ impl PhysicalExpr for CaseExpr {
457501
self.case_column_or_null(batch)
458502
}
459503
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
504+
EvalMethod::ExpressionOrExpression => self.expr_or_expr(batch),
460505
}
461506
}
462507

@@ -1174,6 +1219,45 @@ mod tests {
11741219
Ok(())
11751220
}
11761221

1222+
#[test]
1223+
fn test_expr_or_expr_specialization() -> Result<()> {
1224+
let batch = case_test_batch1()?;
1225+
let schema = batch.schema();
1226+
let when = binary(
1227+
col("a", &schema)?,
1228+
Operator::LtEq,
1229+
lit(2i32),
1230+
&batch.schema(),
1231+
)?;
1232+
let then = binary(
1233+
col("a", &schema)?,
1234+
Operator::Plus,
1235+
lit(1i32),
1236+
&batch.schema(),
1237+
)?;
1238+
let else_expr = binary(
1239+
col("a", &schema)?,
1240+
Operator::Minus,
1241+
lit(1i32),
1242+
&batch.schema(),
1243+
)?;
1244+
let expr = CaseExpr::try_new(None, vec![(when, then)], Some(else_expr))?;
1245+
assert!(matches!(
1246+
expr.eval_method,
1247+
EvalMethod::ExpressionOrExpression
1248+
));
1249+
let result = expr
1250+
.evaluate(&batch)?
1251+
.into_array(batch.num_rows())
1252+
.expect("Failed to convert to array");
1253+
let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
1254+
1255+
let expected = &Int32Array::from(vec![Some(2), Some(1), None, Some(4)]);
1256+
1257+
assert_eq!(expected, result);
1258+
Ok(())
1259+
}
1260+
11771261
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
11781262
Arc::new(Column::new(name, index))
11791263
}

datafusion/sqllogictest/test_files/case.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,14 @@ query I
224224
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
225225
----
226226
1
227+
228+
# CASE WHEN with single predicate and two non-trivial branches (expr or expr usage)
229+
query I
230+
SELECT CASE WHEN a < 5 THEN a + b ELSE b - NVL(a, 0) END FROM foo
231+
----
232+
3
233+
7
234+
1
235+
NULL
236+
NULL
237+
7

0 commit comments

Comments
 (0)