Skip to content

Support inferring new predicates to push down #15906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,25 @@ pub fn normalize_sorts(
}

/// Recursively replace all [`Column`] expressions in a given expression tree with
/// `Column` expressions provided by the hash map argument.
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't wanna write a similar method for the PR, so made the method generic

pub fn replace_col_with_expr(
    expr: Expr,
    replace_map: &HashMap<Column, &Expr>,
) -> Result<Expr> {
    expr.transform(|expr| {
        Ok({
            if let Expr::Column(c) = &expr {
                match replace_map.get(c) {
                    Some(new_expr) => Transformed::yes((**new_expr).to_owned()),
                    None => Transformed::no(expr),
                }
            } else {
                Transformed::no(expr)
            }
        })
    })
    .data()
}

/// the expressions provided by the hash map argument.
///
/// # Arguments
/// * `expr` - The expression to transform
/// * `replace_map` - A mapping from Column to replacement expression
/// * `to_expr` - A function that converts the replacement value to an Expr
pub fn replace_col<V, F>(
expr: Expr,
replace_map: &HashMap<&Column, V>,
to_expr: F,
) -> Result<Expr>
where
F: Fn(&V) -> Expr,
{
expr.transform(|expr| {
Ok({
if let Expr::Column(c) = &expr {
match replace_map.get(c) {
Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
Some(replacement) => Transformed::yes(to_expr(replacement)),
None => Transformed::no(expr),
}
} else {
Expand Down
75 changes: 73 additions & 2 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,10 @@ impl InferredPredicates {
Ok(true)
)
{
self.predicates.push(replace_col(predicate, replace_map)?);
self.predicates
.push(replace_col(predicate, replace_map, |col| {
Expr::Column((*col).clone())
})?);
}

Ok(())
Expand Down Expand Up @@ -784,13 +787,14 @@ impl OptimizerRule for PushDownFilter {

// remove duplicated filters
let child_predicates = split_conjunction_owned(child_filter.predicate);
let new_predicates = parents_predicates
let mut new_predicates = parents_predicates
.into_iter()
.chain(child_predicates)
// use IndexSet to remove dupes while preserving predicate order
.collect::<IndexSet<_>>()
.into_iter()
.collect::<Vec<_>>();
new_predicates = infer_predicates_from_equalities(new_predicates)?;

let Some(new_predicate) = conjunction(new_predicates) else {
return plan_err!("at least one expression exists");
Expand Down Expand Up @@ -1382,6 +1386,73 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
is_contain
}

/// Infers new predicates by substituting equalities.
/// For example, with predicates `t2.b = 3` and `t1.b > t2.b`,
/// we can infer `t1.b > 3`.
fn infer_predicates_from_equalities(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future, we can move the code into a dedicated optimizer rule, such as InferPredicates

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be a special case of the range analysis code in

https://docs.rs/datafusion/latest/datafusion/physical_expr/intervals/cp_solver/index.html

In other words, instead of this special case maybe we could use the cp_solver to create a more general framework for introducing inferred predicates 🤔

Now that we have predicate pushdown for ExecutionPlans maybe it is more realistic to do this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check the cp_solver (didn't notice the part of code before)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great suggestion. I can help if you need some directions or have any confusion

// Map from column names to their literal values (from equality predicates)
let mut equality_map: HashMap<Column, Expr> =
HashMap::with_capacity(predicates.len());
let mut final_predicates = Vec::with_capacity(predicates.len());
// First pass: collect column=literal equalities
for predicate in predicates.iter() {
if let Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) = predicate
{
if let Expr::Column(col) = left.as_ref() {
// Only add to map if right side is a literal
if matches!(right.as_ref(), Expr::Literal(_)) {
equality_map.insert(col.clone(), *right.clone());
final_predicates.push(predicate.clone());
}
} else if let Expr::Column(col) = right.as_ref() {
// Only add to map if left side is a literal
if matches!(left.as_ref(), Expr::Literal(_)) {
equality_map.insert(col.clone(), *right.clone());
final_predicates.push(predicate.clone());
}
}
}
}

// If no equality mappings found, nothing to infer
if equality_map.is_empty() {
return Ok(predicates);
}

// Second pass: apply substitutions to create new predicates
for predicate in predicates {
// Skip equality predicates we already used for mapping
if final_predicates.contains(&predicate) {
continue;
}

// Try to replace columns with their literal values
let mut columns_in_expr = HashSet::new();
expr_to_columns(&predicate, &mut columns_in_expr)?;

// Create a combined replacement map for all columns in this predicate
let replace_map: HashMap<_, _> = columns_in_expr
.iter()
.filter_map(|col| equality_map.get(col).map(|lit| (col, lit)))
.collect();

if replace_map.is_empty() {
final_predicates.push(predicate);
continue;
}
// Apply all substitutions at once to get the fully substituted predicate
let new_pred = replace_col(predicate, &replace_map, |e| (*e).clone())?;

final_predicates.push(new_pred);
}

Ok(final_predicates)
}

#[cfg(test)]
mod tests {
use std::any::Any;
Expand Down
6 changes: 4 additions & 2 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub(crate) fn replace_qualified_name(
let replace_map: HashMap<&Column, &Column> =
cols.iter().zip(alias_cols.iter()).collect();

replace_col(expr, &replace_map)
replace_col(expr, &replace_map, |col| Expr::Column((*col).clone()))
}

/// Log the plan in debug/tracing mode after some part of the optimizer runs
Expand Down Expand Up @@ -136,7 +136,9 @@ fn evaluate_expr_with_null_column<'a>(
.map(|column| (column, &null_column))
.collect::<HashMap<_, _>>();

let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
let replaced_predicate = replace_col(predicate, &join_cols_to_replace, |col| {
Expr::Column((*col).clone())
})?;
let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
.evaluate(&input_batch)
Expand Down
32 changes: 32 additions & 0 deletions datafusion/sqllogictest/test_files/push_down_filter.slt
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,35 @@ logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8

statement ok
drop table t;

statement ok
create table t1(a int, b int) as values(1, 2), (2, 3), (3 ,4);

statement ok
create table t2(a int, b int) as values (1, 2), (2, 4), (4, 5);

query TT
explain select
*
from
t1
join t2 on t1.a = t2.a
and t1.b between t2.b
and t2.b + 2
where
t2.b = 3
----
logical_plan
01)Inner Join: t1.a = t2.a
02)--Projection: t1.a, t1.b
03)----Filter: __common_expr_4 >= Int64(3) AND __common_expr_4 <= Int64(5)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inferred predicate, which can be pushed down to the t1 scan.

04)------Projection: CAST(t1.b AS Int64) AS __common_expr_4, t1.a, t1.b
05)--------TableScan: t1 projection=[a, b]
06)--Filter: t2.b = Int32(3)
07)----TableScan: t2 projection=[a, b]

statement ok
drop table t1;

statement ok
drop table t2;
Loading