Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PruningPredicate's rewrite public #12850

Merged
merged 5 commits into from
Oct 13, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 187 additions & 23 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,37 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions
/// or predicates that reference columns that are not in the schema.
pub trait UnhandledPredicateHook {
/// Called when a predicate can not be handled by DataFusion's transformation rules
/// or is referencing a column that is not in the schema.
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}

#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
default: Arc<dyn PhysicalExpr>,
}

impl ConstantUnhandledPredicateHook {
fn new(default: Arc<dyn PhysicalExpr>) -> Self {
Self { default }
}
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
self.default.clone()
}
}

fn default_unhandled_hook() -> Arc<dyn UnhandledPredicateHook> {
Arc::new(ConstantUnhandledPredicateHook::new(Arc::new(
phys_expr::Literal::new(ScalarValue::Boolean(Some(true))),
)))
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +533,16 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
let unhandled_hook = default_unhandled_hook();

// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1311,28 +1348,78 @@ fn build_is_null_column_expr(
/// The maximum number of entries in an `InList` that might be rewritten into
/// an OR chain
const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
/// Rewrite a predicate expression to a pruning predicate expression.
/// See documentation for `PruningPredicate` for more information.
pub struct PredicateRewriter {
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
}

impl Default for PredicateRewriter {
fn default() -> Self {
Self {
unhandled_hook: default_unhandled_hook(),
}
}
}

impl PredicateRewriter {
/// Create a new `PredicateRewriter`
pub fn new() -> Self {
Self::default()
}

/// Set the unhandled hook to be used when a predicate can not be rewritten
pub fn with_unhandled_hook(
self,
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
) -> Self {
Self { unhandled_hook }
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
pub fn rewrite_predicate_to_statistics_predicate(
&self,
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
) -> Arc<dyn PhysicalExpr> {
let mut required_columns = RequiredColumns::new();
build_predicate_expression(
expr,
schema,
&mut required_columns,
&self.unhandled_hook,
)
}
}

/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
/// Any predicates that can not be translated will be passed to `unhandled_hook`.
///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1428,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1469,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(
&change_expr,
schema,
required_columns,
unhandled_hook,
);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1488,23 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr =
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr =
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1423,12 +1517,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
}
Err(_) => return unhandled_hook.handle(expr),
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder)
.unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1582,6 +1675,8 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_functions_nested::expr_fn::{array_has, make_array};
use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -3397,6 +3492,74 @@ mod tests {
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
struct CustomUnhandledHook;

impl UnhandledPredicateHook for CustomUnhandledHook {
/// This handles an arbitrary case of a column that doesn't exist in the schema
/// by renaming it to yet another column that doesn't exist in the schema
/// (the transformation is arbitrary, the point is that it can do whatever it wants)
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
}
}

let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
let schema_with_b = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]);

let rewriter = PredicateRewriter::new()
.with_unhandled_hook(Arc::new(CustomUnhandledHook {}));

let transform_expr = |expr| {
let expr = logical2physical(&expr, &schema_with_b);
rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema)
};

// transform an arbitrary valid expression that we know is handled
let known_expression = col("a").eq(lit(12));
let known_expression_transformed = PredicateRewriter::new()
.rewrite_predicate_to_statistics_predicate(
&logical2physical(&known_expression, &schema),
&schema,
);

// an expression referencing an unknown column (that is not in the schema) gets passed to the hook
let input = col("b").eq(lit(12));
let expected = logical2physical(&lit(42), &schema);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// more complex case with unknown column
let input = known_expression.clone().and(input.clone());
let expected = phys_expr::BinaryExpr::new(
known_expression_transformed.clone(),
Operator::And,
logical2physical(&lit(42), &schema),
);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// an unknown expression gets passed to the hook
let input = array_has(make_array(vec![lit(1)]), col("a"));
let expected = logical2physical(&lit(42), &schema);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());

// more complex case with unknown expression
let input = known_expression.and(input);
let expected = phys_expr::BinaryExpr::new(
known_expression_transformed.clone(),
Operator::And,
logical2physical(&lit(42), &schema),
);
let transformed = transform_expr(input.clone());
assert_eq!(transformed.to_string(), expected.to_string());
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4049,7 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
let unhandled_hook = default_unhandled_hook();
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}