Skip to content

Commit 9c49413

Browse files
committed
Make PruningPredicate's rewrite public
1 parent 43d0bcf commit 9c49413

File tree

1 file changed

+158
-23
lines changed

1 file changed

+158
-23
lines changed

datafusion/core/src/physical_optimizer/pruning.rs

Lines changed: 158 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,37 @@ pub struct PruningPredicate {
478478
literal_guarantees: Vec<LiteralGuarantee>,
479479
}
480480

481+
/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions
482+
/// or predicates that reference columns that are not in the schema.
483+
pub trait UnhandledPredicateHook {
484+
/// Called when a predicate can not be handled by DataFusion's transformation rules
485+
/// or is referencing a column that is not in the schema.
486+
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
487+
}
488+
489+
#[derive(Debug, Clone)]
490+
struct ConstantUnhandledPredicateHook {
491+
default: Arc<dyn PhysicalExpr>,
492+
}
493+
494+
impl ConstantUnhandledPredicateHook {
495+
fn new(default: Arc<dyn PhysicalExpr>) -> Self {
496+
Self { default }
497+
}
498+
}
499+
500+
impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
501+
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
502+
self.default.clone()
503+
}
504+
}
505+
506+
fn default_unhandled_hook() -> Arc<dyn UnhandledPredicateHook> {
507+
Arc::new(ConstantUnhandledPredicateHook::new(Arc::new(
508+
phys_expr::Literal::new(ScalarValue::Boolean(Some(true))),
509+
)))
510+
}
511+
481512
impl PruningPredicate {
482513
/// Try to create a new instance of [`PruningPredicate`]
483514
///
@@ -502,10 +533,16 @@ impl PruningPredicate {
502533
/// See the struct level documentation on [`PruningPredicate`] for more
503534
/// details.
504535
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
536+
let unhandled_hook = default_unhandled_hook();
537+
505538
// build predicate expression once
506539
let mut required_columns = RequiredColumns::new();
507-
let predicate_expr =
508-
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
540+
let predicate_expr = build_predicate_expression(
541+
&expr,
542+
schema.as_ref(),
543+
&mut required_columns,
544+
&unhandled_hook,
545+
);
509546

510547
let literal_guarantees = LiteralGuarantee::analyze(&expr);
511548

@@ -1316,23 +1353,43 @@ const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
13161353
/// expression that will evaluate to FALSE if it can be determined no
13171354
/// rows between the min/max values could pass the predicates.
13181355
///
1356+
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
1357+
///
13191358
/// Returns the pruning predicate as an [`PhysicalExpr`]
13201359
///
1321-
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
1360+
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
1361+
pub fn rewrite_predicate_to_statistics_predicate(
1362+
expr: &Arc<dyn PhysicalExpr>,
1363+
schema: &Schema,
1364+
unhandled_hook: Option<Arc<dyn UnhandledPredicateHook>>,
1365+
) -> Arc<dyn PhysicalExpr> {
1366+
let unhandled_hook = unhandled_hook.unwrap_or(default_unhandled_hook());
1367+
1368+
let mut required_columns = RequiredColumns::new();
1369+
1370+
build_predicate_expression(expr, schema, &mut required_columns, &unhandled_hook)
1371+
}
1372+
1373+
/// Translate logical filter expression into pruning predicate
1374+
/// expression that will evaluate to FALSE if it can be determined no
1375+
/// rows between the min/max values could pass the predicates.
1376+
///
1377+
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
1378+
///
1379+
/// Returns the pruning predicate as an [`PhysicalExpr`]
1380+
///
1381+
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
13221382
fn build_predicate_expression(
13231383
expr: &Arc<dyn PhysicalExpr>,
13241384
schema: &Schema,
13251385
required_columns: &mut RequiredColumns,
1386+
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
13261387
) -> Arc<dyn PhysicalExpr> {
1327-
// Returned for unsupported expressions. Such expressions are
1328-
// converted to TRUE.
1329-
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));
1330-
13311388
// predicate expression can only be a binary expression
13321389
let expr_any = expr.as_any();
13331390
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
13341391
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
1335-
.unwrap_or(unhandled);
1392+
.unwrap_or_else(|| unhandled_hook.handle(expr));
13361393
}
13371394
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
13381395
return build_is_null_column_expr(
@@ -1341,19 +1398,19 @@ fn build_predicate_expression(
13411398
required_columns,
13421399
true,
13431400
)
1344-
.unwrap_or(unhandled);
1401+
.unwrap_or_else(|| unhandled_hook.handle(expr));
13451402
}
13461403
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
13471404
return build_single_column_expr(col, schema, required_columns, false)
1348-
.unwrap_or(unhandled);
1405+
.unwrap_or_else(|| unhandled_hook.handle(expr));
13491406
}
13501407
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
13511408
// match !col (don't do so recursively)
13521409
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
13531410
return build_single_column_expr(col, schema, required_columns, true)
1354-
.unwrap_or(unhandled);
1411+
.unwrap_or_else(|| unhandled_hook.handle(expr));
13551412
} else {
1356-
return unhandled;
1413+
return unhandled_hook.handle(expr);
13571414
}
13581415
}
13591416
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
@@ -1382,9 +1439,14 @@ fn build_predicate_expression(
13821439
})
13831440
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
13841441
.unwrap();
1385-
return build_predicate_expression(&change_expr, schema, required_columns);
1442+
return build_predicate_expression(
1443+
&change_expr,
1444+
schema,
1445+
required_columns,
1446+
unhandled_hook,
1447+
);
13861448
} else {
1387-
return unhandled;
1449+
return unhandled_hook.handle(expr);
13881450
}
13891451
}
13901452

@@ -1396,21 +1458,23 @@ fn build_predicate_expression(
13961458
bin_expr.right().clone(),
13971459
)
13981460
} else {
1399-
return unhandled;
1461+
return unhandled_hook.handle(expr);
14001462
}
14011463
};
14021464

14031465
if op == Operator::And || op == Operator::Or {
1404-
let left_expr = build_predicate_expression(&left, schema, required_columns);
1405-
let right_expr = build_predicate_expression(&right, schema, required_columns);
1466+
let left_expr =
1467+
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
1468+
let right_expr =
1469+
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
14061470
// simplify boolean expression if applicable
14071471
let expr = match (&left_expr, op, &right_expr) {
14081472
(left, Operator::And, _) if is_always_true(left) => right_expr,
14091473
(_, Operator::And, right) if is_always_true(right) => left_expr,
14101474
(left, Operator::Or, right)
14111475
if is_always_true(left) || is_always_true(right) =>
14121476
{
1413-
unhandled
1477+
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
14141478
}
14151479
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
14161480
};
@@ -1423,12 +1487,11 @@ fn build_predicate_expression(
14231487
Ok(builder) => builder,
14241488
// allow partial failure in predicate expression generation
14251489
// this can still produce a useful predicate when multiple conditions are joined using AND
1426-
Err(_) => {
1427-
return unhandled;
1428-
}
1490+
Err(_) => return unhandled_hook.handle(expr),
14291491
};
14301492

1431-
build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
1493+
build_statistics_expr(&mut expr_builder)
1494+
.unwrap_or_else(|_| unhandled_hook.handle(expr))
14321495
}
14331496

14341497
fn build_statistics_expr(
@@ -1582,6 +1645,8 @@ mod tests {
15821645
use arrow_array::UInt64Array;
15831646
use datafusion_expr::expr::InList;
15841647
use datafusion_expr::{cast, is_null, try_cast, Expr};
1648+
use datafusion_functions_nested::expr_fn::{array_has, make_array};
1649+
use datafusion_physical_expr::expressions as phys_expr;
15851650
use datafusion_physical_expr::planner::logical2physical;
15861651

15871652
#[derive(Debug, Default)]
@@ -3397,6 +3462,75 @@ mod tests {
33973462
// TODO: add test for other case and op
33983463
}
33993464

3465+
#[test]
3466+
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
3467+
struct CustomUnhandledHook;
3468+
3469+
impl UnhandledPredicateHook for CustomUnhandledHook {
3470+
/// This handles an arbitrary case of a column that doesn't exist in the schema
3471+
/// by renaming it to yet another column that doesn't exist in the schema
3472+
/// (the transformation is arbitrary, the point is that it can do whatever it wants)
3473+
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
3474+
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
3475+
}
3476+
}
3477+
3478+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
3479+
let schema_with_b = Schema::new(vec![
3480+
Field::new("a", DataType::Int32, true),
3481+
Field::new("b", DataType::Int32, true),
3482+
]);
3483+
3484+
let transform_expr = |expr| {
3485+
let expr = logical2physical(&expr, &schema_with_b);
3486+
rewrite_predicate_to_statistics_predicate(
3487+
&expr,
3488+
&schema,
3489+
Some(Arc::new(CustomUnhandledHook {})),
3490+
)
3491+
};
3492+
3493+
// transform an arbitrary valid expression that we know is handled
3494+
let known_expression = col("a").eq(lit(ScalarValue::Int32(Some(12))));
3495+
let known_expression_transformed = rewrite_predicate_to_statistics_predicate(
3496+
&logical2physical(&known_expression, &schema),
3497+
&schema,
3498+
None,
3499+
);
3500+
3501+
// an expression referencing an unknown column (that is not in the schema) gets passed to the hook
3502+
let input = col("b").eq(lit(ScalarValue::Int32(Some(12))));
3503+
let expected = logical2physical(&lit(42), &schema);
3504+
let transformed = transform_expr(input.clone());
3505+
assert_eq!(transformed.to_string(), expected.to_string());
3506+
3507+
// more complex case with unknown column
3508+
let input = known_expression.clone().and(input.clone());
3509+
let expected = phys_expr::BinaryExpr::new(
3510+
known_expression_transformed.clone(),
3511+
Operator::And,
3512+
logical2physical(&lit(42), &schema),
3513+
);
3514+
let transformed = transform_expr(input.clone());
3515+
assert_eq!(transformed.to_string(), expected.to_string());
3516+
3517+
// an unknown expression gets passed to the hook
3518+
let input = array_has(make_array(vec![lit(1)]), col("a"));
3519+
let expected = logical2physical(&lit(42), &schema);
3520+
let transformed = transform_expr(input.clone());
3521+
assert_eq!(transformed.to_string(), expected.to_string());
3522+
3523+
// more complex case with unknown expression
3524+
let input = known_expression.and(input);
3525+
let expected = phys_expr::BinaryExpr::new(
3526+
known_expression_transformed.clone(),
3527+
Operator::And,
3528+
logical2physical(&lit(42), &schema),
3529+
);
3530+
let transformed = transform_expr(input.clone());
3531+
assert_eq!(transformed.to_string(), expected.to_string());
3532+
}
3533+
34003534
#[test]
34013535
fn test_rewrite_expr_to_prunable_error() {
34023536
// cast string value to numeric value
@@ -3886,6 +4020,7 @@ mod tests {
38864020
required_columns: &mut RequiredColumns,
38874021
) -> Arc<dyn PhysicalExpr> {
38884022
let expr = logical2physical(expr, schema);
3889-
build_predicate_expression(&expr, schema, required_columns)
4023+
let unhandled_hook = default_unhandled_hook();
4024+
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
38904025
}
38914026
}

0 commit comments

Comments
 (0)