Skip to content

Commit 2677c27

Browse files
authored
[branch-51] Revert rewrite for coalesce, nvl and nvl2 simplification (#18567)
Note this targets the branch-51 release branch ## Which issue does this PR close? - part of #17558 - resolves #17801 in the 51 release branch ## Rationale for this change - We merged some clever rewrites for `coalesce` and `nvl2` to use `CASE` which are faster and more correct (👏 @chenkovsky @kosiew ) - However, these rewrites cause subtle schema mismatches in some cases planning (b/c the CASE simplification nullability logic can't determine the correct nullability in some cases - see #17801) - @pepijnve has some heroic efforts to fix the schema mismatch in #17813 (comment), but it is non trivial and I am worried about merging it so close to the 51 release and introducing new edge cases ## What changes are included in this PR? 1. Revert #17357 / e5dcc8c 3. Revert #17991 / ea83c26 2. Revert #18191 / 22c4214 2. Cherry-pick 6202254, a test that reproduces the schema mismatch issue (from #18536) 3. Cherry-pick 735cacf, a fix for the benchmarks that regressed due to the revert (from #17833) 4. Update datafusion-testing (see separate PR here) for extended tests (see apache/datafusion-testing#15) ## Are these changes tested? Yes I added a new test ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent ff96b3b commit 2677c27

File tree

14 files changed

+329
-276
lines changed

14 files changed

+329
-276
lines changed

datafusion-testing

Submodule datafusion-testing updated 85 files

datafusion/core/benches/sql_planner.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,6 @@ fn criterion_benchmark(c: &mut Criterion) {
477477
};
478478

479479
let raw_tpcds_sql_queries = (1..100)
480-
// skip query 75 until it is fixed
481-
// https://github.com/apache/datafusion/issues/17801
482-
.filter(|q| *q != 75)
483480
.map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap())
484481
.collect::<Vec<_>>();
485482

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -274,33 +274,6 @@ async fn test_nvl2() -> Result<()> {
274274

275275
Ok(())
276276
}
277-
278-
#[tokio::test]
279-
async fn test_nvl2_short_circuit() -> Result<()> {
280-
let expr = nvl2(
281-
col("a"),
282-
arrow_cast(lit("1"), lit("Int32")),
283-
arrow_cast(col("a"), lit("Int32")),
284-
);
285-
286-
let batches = get_batches(expr).await?;
287-
288-
assert_snapshot!(
289-
batches_to_string(&batches),
290-
@r#"
291-
+-----------------------------------------------------------------------------------+
292-
| nvl2(test.a,arrow_cast(Utf8("1"),Utf8("Int32")),arrow_cast(test.a,Utf8("Int32"))) |
293-
+-----------------------------------------------------------------------------------+
294-
| 1 |
295-
| 1 |
296-
| 1 |
297-
| 1 |
298-
+-----------------------------------------------------------------------------------+
299-
"#
300-
);
301-
302-
Ok(())
303-
}
304277
#[tokio::test]
305278
async fn test_fn_arrow_typeof() -> Result<()> {
306279
let expr = arrow_typeof(col("l"));

datafusion/core/tests/expr_api/mod.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -320,26 +320,6 @@ async fn test_create_physical_expr() {
320320
create_simplified_expr_test(lit(1i32) + lit(2i32), "3");
321321
}
322322

323-
#[test]
324-
fn test_create_physical_expr_nvl2() {
325-
let batch = &TEST_BATCH;
326-
let df_schema = DFSchema::try_from(batch.schema()).unwrap();
327-
let ctx = SessionContext::new();
328-
329-
let expect_err = |expr| {
330-
let physical_expr = ctx.create_physical_expr(expr, &df_schema).unwrap();
331-
let err = physical_expr.evaluate(batch).unwrap_err();
332-
assert!(
333-
err.to_string()
334-
.contains("nvl2 should have been simplified to case"),
335-
"unexpected error: {err:?}"
336-
);
337-
};
338-
339-
expect_err(nvl2(col("i"), lit(1i64), lit(0i64)));
340-
expect_err(nvl2(lit(1i64), col("i"), lit(0i64)));
341-
}
342-
343323
#[tokio::test]
344324
async fn test_create_physical_expr_coercion() {
345325
// create_physical_expr does apply type coercion and unwrapping in cast

datafusion/core/tests/tpcds_planning.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,10 +1051,13 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> {
10511051

10521052
for sql in &sql {
10531053
let df = ctx.sql(sql).await?;
1054-
let (state, plan) = df.into_parts();
1055-
let plan = state.optimize(&plan)?;
1056-
if create_physical {
1057-
let _ = state.create_physical_plan(&plan).await?;
1054+
// attempt to mimic planning steps
1055+
if !create_physical {
1056+
let (state, plan) = df.into_parts();
1057+
let _ = state.optimize(&plan)?;
1058+
} else {
1059+
// this is what df.execute() does internally
1060+
let _ = df.create_physical_plan().await?;
10581061
}
10591062
}
10601063

datafusion/expr/src/udf.rs

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -252,21 +252,7 @@ impl ScalarUDF {
252252
Ok(result)
253253
}
254254

255-
/// Determines which of the arguments passed to this function are evaluated eagerly
256-
/// and which may be evaluated lazily.
257-
///
258-
/// See [ScalarUDFImpl::conditional_arguments] for more information.
259-
pub fn conditional_arguments<'a>(
260-
&self,
261-
args: &'a [Expr],
262-
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
263-
self.inner.conditional_arguments(args)
264-
}
265-
266-
/// Returns true if some of this `exprs` subexpressions may not be evaluated
267-
/// and thus any side effects (like divide by zero) may not be encountered.
268-
///
269-
/// See [ScalarUDFImpl::short_circuits] for more information.
255+
/// Get the circuits of inner implementation
270256
pub fn short_circuits(&self) -> bool {
271257
self.inner.short_circuits()
272258
}
@@ -696,42 +682,10 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
696682
///
697683
/// Setting this to true prevents certain optimizations such as common
698684
/// subexpression elimination
699-
///
700-
/// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be
701-
/// overridden to report more accurately which arguments are eagerly evaluated and which ones
702-
/// lazily.
703685
fn short_circuits(&self) -> bool {
704686
false
705687
}
706688

707-
/// Determines which of the arguments passed to this function are evaluated eagerly
708-
/// and which may be evaluated lazily.
709-
///
710-
/// If this function returns `None`, all arguments are eagerly evaluated.
711-
/// Returning `None` is a micro optimization that saves a needless `Vec`
712-
/// allocation.
713-
///
714-
/// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
715-
/// are the arguments that are always evaluated, and `lazy` are the
716-
/// arguments that may be evaluated lazily (i.e. may not be evaluated at all
717-
/// in some cases).
718-
///
719-
/// Implementations must ensure that the two returned `Vec`s are disjunct,
720-
/// and that each argument from `args` is present in one the two `Vec`s.
721-
///
722-
/// When overriding this function, [ScalarUDFImpl::short_circuits] must
723-
/// be overridden to return `true`.
724-
fn conditional_arguments<'a>(
725-
&self,
726-
args: &'a [Expr],
727-
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
728-
if self.short_circuits() {
729-
Some((vec![], args.iter().collect()))
730-
} else {
731-
None
732-
}
733-
}
734-
735689
/// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
736690
/// intervals.
737691
///
@@ -921,13 +875,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
921875
self.inner.simplify(args, info)
922876
}
923877

924-
fn conditional_arguments<'a>(
925-
&self,
926-
args: &'a [Expr],
927-
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
928-
self.inner.conditional_arguments(args)
929-
}
930-
931878
fn short_circuits(&self) -> bool {
932879
self.inner.short_circuits()
933880
}

datafusion/functions/src/core/coalesce.rs

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::{new_null_array, BooleanArray};
19+
use arrow::compute::kernels::zip::zip;
20+
use arrow::compute::{and, is_not_null, is_null};
1821
use arrow::datatypes::{DataType, Field, FieldRef};
19-
use datafusion_common::{exec_err, internal_err, plan_err, Result};
22+
use datafusion_common::{exec_err, internal_err, Result};
2023
use datafusion_expr::binary::try_type_union_resolution;
21-
use datafusion_expr::conditional_expressions::CaseBuilder;
22-
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
2324
use datafusion_expr::{
24-
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
25+
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
2526
};
2627
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2728
use datafusion_macros::user_doc;
@@ -47,7 +48,7 @@ use std::any::Any;
4748
)]
4849
#[derive(Debug, PartialEq, Eq, Hash)]
4950
pub struct CoalesceFunc {
50-
pub(super) signature: Signature,
51+
signature: Signature,
5152
}
5253

5354
impl Default for CoalesceFunc {
@@ -94,45 +95,61 @@ impl ScalarUDFImpl for CoalesceFunc {
9495
Ok(Field::new(self.name(), return_type, nullable).into())
9596
}
9697

97-
fn simplify(
98-
&self,
99-
args: Vec<Expr>,
100-
_info: &dyn SimplifyInfo,
101-
) -> Result<ExprSimplifyResult> {
98+
/// coalesce evaluates to the first value which is not NULL
99+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100+
let args = args.args;
101+
// do not accept 0 arguments.
102102
if args.is_empty() {
103-
return plan_err!("coalesce must have at least one argument");
104-
}
105-
if args.len() == 1 {
106-
return Ok(ExprSimplifyResult::Simplified(
107-
args.into_iter().next().unwrap(),
108-
));
103+
return exec_err!(
104+
"coalesce was called with {} arguments. It requires at least 1.",
105+
args.len()
106+
);
109107
}
110108

111-
let n = args.len();
112-
let (init, last_elem) = args.split_at(n - 1);
113-
let whens = init
114-
.iter()
115-
.map(|x| x.clone().is_not_null())
116-
.collect::<Vec<_>>();
117-
let cases = init.to_vec();
118-
Ok(ExprSimplifyResult::Simplified(
119-
CaseBuilder::new(None, whens, cases, Some(Box::new(last_elem[0].clone())))
120-
.end()?,
121-
))
122-
}
123-
124-
/// coalesce evaluates to the first value which is not NULL
125-
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
126-
internal_err!("coalesce should have been simplified to case")
127-
}
128-
129-
fn conditional_arguments<'a>(
130-
&self,
131-
args: &'a [Expr],
132-
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
133-
let eager = vec![&args[0]];
134-
let lazy = args[1..].iter().collect();
135-
Some((eager, lazy))
109+
let return_type = args[0].data_type();
110+
let mut return_array = args.iter().filter_map(|x| match x {
111+
ColumnarValue::Array(array) => Some(array.len()),
112+
_ => None,
113+
});
114+
115+
if let Some(size) = return_array.next() {
116+
// start with nulls as default output
117+
let mut current_value = new_null_array(&return_type, size);
118+
let mut remainder = BooleanArray::from(vec![true; size]);
119+
120+
for arg in args {
121+
match arg {
122+
ColumnarValue::Array(ref array) => {
123+
let to_apply = and(&remainder, &is_not_null(array.as_ref())?)?;
124+
current_value = zip(&to_apply, array, &current_value)?;
125+
remainder = and(&remainder, &is_null(array)?)?;
126+
}
127+
ColumnarValue::Scalar(value) => {
128+
if value.is_null() {
129+
continue;
130+
} else {
131+
let last_value = value.to_scalar()?;
132+
current_value = zip(&remainder, &last_value, &current_value)?;
133+
break;
134+
}
135+
}
136+
}
137+
if remainder.iter().all(|x| x == Some(false)) {
138+
break;
139+
}
140+
}
141+
Ok(ColumnarValue::Array(current_value))
142+
} else {
143+
let result = args
144+
.iter()
145+
.filter_map(|x| match x {
146+
ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()),
147+
_ => None,
148+
})
149+
.next()
150+
.unwrap_or_else(|| args[0].clone());
151+
Ok(result)
152+
}
136153
}
137154

138155
fn short_circuits(&self) -> bool {

0 commit comments

Comments
 (0)