@@ -478,6 +478,37 @@ pub struct PruningPredicate {
478
478
literal_guarantees : Vec < LiteralGuarantee > ,
479
479
}
480
480
481
/// Hook invoked for predicates that DataFusion cannot rewrite into a pruning
/// predicate, e.g. certain complex expressions or predicates that reference
/// columns that are not in the schema.
///
/// The expression returned by [`UnhandledPredicateHook::handle`] is used in
/// place of the predicate that could not be translated.
pub trait UnhandledPredicateHook {
    /// Called when a predicate cannot be handled by DataFusion's
    /// transformation rules or references a column that is not in the schema.
    ///
    /// `expr` is the expression that could not be translated; the returned
    /// expression is substituted for it in the rewritten predicate.
    fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}
488
+
489
+ #[ derive( Debug , Clone ) ]
490
+ struct ConstantUnhandledPredicateHook {
491
+ default : Arc < dyn PhysicalExpr > ,
492
+ }
493
+
494
+ impl ConstantUnhandledPredicateHook {
495
+ fn new ( default : Arc < dyn PhysicalExpr > ) -> Self {
496
+ Self { default }
497
+ }
498
+ }
499
+
500
+ impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
501
+ fn handle ( & self , _expr : & Arc < dyn PhysicalExpr > ) -> Arc < dyn PhysicalExpr > {
502
+ self . default . clone ( )
503
+ }
504
+ }
505
+
506
+ fn default_unhandled_hook ( ) -> Arc < dyn UnhandledPredicateHook > {
507
+ Arc :: new ( ConstantUnhandledPredicateHook :: new ( Arc :: new (
508
+ phys_expr:: Literal :: new ( ScalarValue :: Boolean ( Some ( true ) ) ) ,
509
+ ) ) )
510
+ }
511
+
481
512
impl PruningPredicate {
482
513
/// Try to create a new instance of [`PruningPredicate`]
483
514
///
@@ -502,10 +533,16 @@ impl PruningPredicate {
502
533
/// See the struct level documentation on [`PruningPredicate`] for more
503
534
/// details.
504
535
pub fn try_new ( expr : Arc < dyn PhysicalExpr > , schema : SchemaRef ) -> Result < Self > {
536
+ let unhandled_hook = default_unhandled_hook ( ) ;
537
+
505
538
// build predicate expression once
506
539
let mut required_columns = RequiredColumns :: new ( ) ;
507
- let predicate_expr =
508
- build_predicate_expression ( & expr, schema. as_ref ( ) , & mut required_columns) ;
540
+ let predicate_expr = build_predicate_expression (
541
+ & expr,
542
+ schema. as_ref ( ) ,
543
+ & mut required_columns,
544
+ & unhandled_hook,
545
+ ) ;
509
546
510
547
let literal_guarantees = LiteralGuarantee :: analyze ( & expr) ;
511
548
@@ -1316,23 +1353,43 @@ const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
1316
1353
/// expression that will evaluate to FALSE if it can be determined no
1317
1354
/// rows between the min/max values could pass the predicates.
1318
1355
///
1356
+ /// Any predicates that can not be translated will be passed to `unhandled_hook`.
1357
+ ///
1319
1358
/// Returns the pruning predicate as an [`PhysicalExpr`]
1320
1359
///
1321
- /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
1360
+ /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
1361
+ pub fn rewrite_predicate_to_statistics_predicate (
1362
+ expr : & Arc < dyn PhysicalExpr > ,
1363
+ schema : & Schema ,
1364
+ unhandled_hook : Option < Arc < dyn UnhandledPredicateHook > > ,
1365
+ ) -> Arc < dyn PhysicalExpr > {
1366
+ let unhandled_hook = unhandled_hook. unwrap_or ( default_unhandled_hook ( ) ) ;
1367
+
1368
+ let mut required_columns = RequiredColumns :: new ( ) ;
1369
+
1370
+ build_predicate_expression ( expr, schema, & mut required_columns, & unhandled_hook)
1371
+ }
1372
+
1373
+ /// Translate logical filter expression into pruning predicate
1374
+ /// expression that will evaluate to FALSE if it can be determined no
1375
+ /// rows between the min/max values could pass the predicates.
1376
+ ///
1377
+ /// Any predicates that can not be translated will be passed to `unhandled_hook`.
1378
+ ///
1379
+ /// Returns the pruning predicate as an [`PhysicalExpr`]
1380
+ ///
1381
+ /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
1322
1382
fn build_predicate_expression (
1323
1383
expr : & Arc < dyn PhysicalExpr > ,
1324
1384
schema : & Schema ,
1325
1385
required_columns : & mut RequiredColumns ,
1386
+ unhandled_hook : & Arc < dyn UnhandledPredicateHook > ,
1326
1387
) -> Arc < dyn PhysicalExpr > {
1327
- // Returned for unsupported expressions. Such expressions are
1328
- // converted to TRUE.
1329
- let unhandled = Arc :: new ( phys_expr:: Literal :: new ( ScalarValue :: Boolean ( Some ( true ) ) ) ) ;
1330
-
1331
1388
// predicate expression can only be a binary expression
1332
1389
let expr_any = expr. as_any ( ) ;
1333
1390
if let Some ( is_null) = expr_any. downcast_ref :: < phys_expr:: IsNullExpr > ( ) {
1334
1391
return build_is_null_column_expr ( is_null. arg ( ) , schema, required_columns, false )
1335
- . unwrap_or ( unhandled ) ;
1392
+ . unwrap_or_else ( || unhandled_hook . handle ( expr ) ) ;
1336
1393
}
1337
1394
if let Some ( is_not_null) = expr_any. downcast_ref :: < phys_expr:: IsNotNullExpr > ( ) {
1338
1395
return build_is_null_column_expr (
@@ -1341,19 +1398,19 @@ fn build_predicate_expression(
1341
1398
required_columns,
1342
1399
true ,
1343
1400
)
1344
- . unwrap_or ( unhandled ) ;
1401
+ . unwrap_or_else ( || unhandled_hook . handle ( expr ) ) ;
1345
1402
}
1346
1403
if let Some ( col) = expr_any. downcast_ref :: < phys_expr:: Column > ( ) {
1347
1404
return build_single_column_expr ( col, schema, required_columns, false )
1348
- . unwrap_or ( unhandled ) ;
1405
+ . unwrap_or_else ( || unhandled_hook . handle ( expr ) ) ;
1349
1406
}
1350
1407
if let Some ( not) = expr_any. downcast_ref :: < phys_expr:: NotExpr > ( ) {
1351
1408
// match !col (don't do so recursively)
1352
1409
if let Some ( col) = not. arg ( ) . as_any ( ) . downcast_ref :: < phys_expr:: Column > ( ) {
1353
1410
return build_single_column_expr ( col, schema, required_columns, true )
1354
- . unwrap_or ( unhandled ) ;
1411
+ . unwrap_or_else ( || unhandled_hook . handle ( expr ) ) ;
1355
1412
} else {
1356
- return unhandled ;
1413
+ return unhandled_hook . handle ( expr ) ;
1357
1414
}
1358
1415
}
1359
1416
if let Some ( in_list) = expr_any. downcast_ref :: < phys_expr:: InListExpr > ( ) {
@@ -1382,9 +1439,14 @@ fn build_predicate_expression(
1382
1439
} )
1383
1440
. reduce ( |a, b| Arc :: new ( phys_expr:: BinaryExpr :: new ( a, re_op, b) ) as _ )
1384
1441
. unwrap ( ) ;
1385
- return build_predicate_expression ( & change_expr, schema, required_columns) ;
1442
+ return build_predicate_expression (
1443
+ & change_expr,
1444
+ schema,
1445
+ required_columns,
1446
+ unhandled_hook,
1447
+ ) ;
1386
1448
} else {
1387
- return unhandled ;
1449
+ return unhandled_hook . handle ( expr ) ;
1388
1450
}
1389
1451
}
1390
1452
@@ -1396,21 +1458,23 @@ fn build_predicate_expression(
1396
1458
bin_expr. right ( ) . clone ( ) ,
1397
1459
)
1398
1460
} else {
1399
- return unhandled ;
1461
+ return unhandled_hook . handle ( expr ) ;
1400
1462
}
1401
1463
} ;
1402
1464
1403
1465
if op == Operator :: And || op == Operator :: Or {
1404
- let left_expr = build_predicate_expression ( & left, schema, required_columns) ;
1405
- let right_expr = build_predicate_expression ( & right, schema, required_columns) ;
1466
+ let left_expr =
1467
+ build_predicate_expression ( & left, schema, required_columns, unhandled_hook) ;
1468
+ let right_expr =
1469
+ build_predicate_expression ( & right, schema, required_columns, unhandled_hook) ;
1406
1470
// simplify boolean expression if applicable
1407
1471
let expr = match ( & left_expr, op, & right_expr) {
1408
1472
( left, Operator :: And , _) if is_always_true ( left) => right_expr,
1409
1473
( _, Operator :: And , right) if is_always_true ( right) => left_expr,
1410
1474
( left, Operator :: Or , right)
1411
1475
if is_always_true ( left) || is_always_true ( right) =>
1412
1476
{
1413
- unhandled
1477
+ Arc :: new ( phys_expr :: Literal :: new ( ScalarValue :: Boolean ( Some ( true ) ) ) )
1414
1478
}
1415
1479
_ => Arc :: new ( phys_expr:: BinaryExpr :: new ( left_expr, op, right_expr) ) ,
1416
1480
} ;
@@ -1423,12 +1487,11 @@ fn build_predicate_expression(
1423
1487
Ok ( builder) => builder,
1424
1488
// allow partial failure in predicate expression generation
1425
1489
// this can still produce a useful predicate when multiple conditions are joined using AND
1426
- Err ( _) => {
1427
- return unhandled;
1428
- }
1490
+ Err ( _) => return unhandled_hook. handle ( expr) ,
1429
1491
} ;
1430
1492
1431
- build_statistics_expr ( & mut expr_builder) . unwrap_or ( unhandled)
1493
+ build_statistics_expr ( & mut expr_builder)
1494
+ . unwrap_or_else ( |_| unhandled_hook. handle ( expr) )
1432
1495
}
1433
1496
1434
1497
fn build_statistics_expr (
@@ -1582,6 +1645,8 @@ mod tests {
1582
1645
use arrow_array:: UInt64Array ;
1583
1646
use datafusion_expr:: expr:: InList ;
1584
1647
use datafusion_expr:: { cast, is_null, try_cast, Expr } ;
1648
+ use datafusion_functions_nested:: expr_fn:: { array_has, make_array} ;
1649
+ use datafusion_physical_expr:: expressions as phys_expr;
1585
1650
use datafusion_physical_expr:: planner:: logical2physical;
1586
1651
1587
1652
#[ derive( Debug , Default ) ]
@@ -3397,6 +3462,75 @@ mod tests {
3397
3462
// TODO: add test for other case and op
3398
3463
}
3399
3464
3465
+ #[ test]
3466
+ fn test_rewrite_expr_to_prunable_custom_unhandled_hook ( ) {
3467
+ struct CustomUnhandledHook ;
3468
+
3469
+ impl UnhandledPredicateHook for CustomUnhandledHook {
3470
+ /// This handles an arbitrary case of a column that doesn't exist in the schema
3471
+ /// by renaming it to yet another column that doesn't exist in the schema
3472
+ /// (the transformation is arbitrary, the point is that it can do whatever it wants)
3473
+ fn handle ( & self , _expr : & Arc < dyn PhysicalExpr > ) -> Arc < dyn PhysicalExpr > {
3474
+ Arc :: new ( phys_expr:: Literal :: new ( ScalarValue :: Int32 ( Some ( 42 ) ) ) )
3475
+ }
3476
+ }
3477
+
3478
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , true ) ] ) ;
3479
+ let schema_with_b = Schema :: new ( vec ! [
3480
+ Field :: new( "a" , DataType :: Int32 , true ) ,
3481
+ Field :: new( "b" , DataType :: Int32 , true ) ,
3482
+ ] ) ;
3483
+
3484
+ let transform_expr = |expr| {
3485
+ let expr = logical2physical ( & expr, & schema_with_b) ;
3486
+ rewrite_predicate_to_statistics_predicate (
3487
+ & expr,
3488
+ & schema,
3489
+ Some ( Arc :: new ( CustomUnhandledHook { } ) ) ,
3490
+ )
3491
+ } ;
3492
+
3493
+ // transform an arbitrary valid expression that we know is handled
3494
+ let known_expression = col ( "a" ) . eq ( lit ( ScalarValue :: Int32 ( Some ( 12 ) ) ) ) ;
3495
+ let known_expression_transformed = rewrite_predicate_to_statistics_predicate (
3496
+ & logical2physical ( & known_expression, & schema) ,
3497
+ & schema,
3498
+ None ,
3499
+ ) ;
3500
+
3501
+ // an expression referencing an unknown column (that is not in the schema) gets passed to the hook
3502
+ let input = col ( "b" ) . eq ( lit ( ScalarValue :: Int32 ( Some ( 12 ) ) ) ) ;
3503
+ let expected = logical2physical ( & lit ( 42 ) , & schema) ;
3504
+ let transformed = transform_expr ( input. clone ( ) ) ;
3505
+ assert_eq ! ( transformed. to_string( ) , expected. to_string( ) ) ;
3506
+
3507
+ // more complex case with unknown column
3508
+ let input = known_expression. clone ( ) . and ( input. clone ( ) ) ;
3509
+ let expected = phys_expr:: BinaryExpr :: new (
3510
+ known_expression_transformed. clone ( ) ,
3511
+ Operator :: And ,
3512
+ logical2physical ( & lit ( 42 ) , & schema) ,
3513
+ ) ;
3514
+ let transformed = transform_expr ( input. clone ( ) ) ;
3515
+ assert_eq ! ( transformed. to_string( ) , expected. to_string( ) ) ;
3516
+
3517
+ // an unknown expression gets passed to the hook
3518
+ let input = array_has ( make_array ( vec ! [ lit( 1 ) ] ) , col ( "a" ) ) ;
3519
+ let expected = logical2physical ( & lit ( 42 ) , & schema) ;
3520
+ let transformed = transform_expr ( input. clone ( ) ) ;
3521
+ assert_eq ! ( transformed. to_string( ) , expected. to_string( ) ) ;
3522
+
3523
+ // more complex case with unknown expression
3524
+ let input = known_expression. and ( input) ;
3525
+ let expected = phys_expr:: BinaryExpr :: new (
3526
+ known_expression_transformed. clone ( ) ,
3527
+ Operator :: And ,
3528
+ logical2physical ( & lit ( 42 ) , & schema) ,
3529
+ ) ;
3530
+ let transformed = transform_expr ( input. clone ( ) ) ;
3531
+ assert_eq ! ( transformed. to_string( ) , expected. to_string( ) ) ;
3532
+ }
3533
+
3400
3534
#[ test]
3401
3535
fn test_rewrite_expr_to_prunable_error ( ) {
3402
3536
// cast string value to numeric value
@@ -3886,6 +4020,7 @@ mod tests {
3886
4020
required_columns : & mut RequiredColumns ,
3887
4021
) -> Arc < dyn PhysicalExpr > {
3888
4022
let expr = logical2physical ( expr, schema) ;
3889
- build_predicate_expression ( & expr, schema, required_columns)
4023
+ let unhandled_hook = default_unhandled_hook ( ) ;
4024
+ build_predicate_expression ( & expr, schema, required_columns, & unhandled_hook)
3890
4025
}
3891
4026
}
0 commit comments