Skip to content

Commit d0aac59

Browse files
committed
Support inferring new predicates to push down
1 parent af99b54 commit d0aac59

File tree

4 files changed

+124
-7
lines changed

4 files changed

+124
-7
lines changed

datafusion/expr/src/expr_rewriter/mod.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,25 @@ pub fn normalize_sorts(
131131
}
132132

133133
/// Recursively replace all [`Column`] expressions in a given expression tree with
134-
/// `Column` expressions provided by the hash map argument.
135-
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
134+
/// the expressions provided by the hash map argument.
135+
///
136+
/// # Arguments
137+
/// * `expr` - The expression to transform
138+
/// * `replace_map` - A mapping from each [`Column`] to the value that replaces it (turned into an `Expr` by `to_expr`)
139+
/// * `to_expr` - A function that converts the replacement value to an Expr
140+
pub fn replace_col<V, F>(
141+
expr: Expr,
142+
replace_map: &HashMap<&Column, V>,
143+
to_expr: F,
144+
) -> Result<Expr>
145+
where
146+
F: Fn(&V) -> Expr,
147+
{
136148
expr.transform(|expr| {
137149
Ok({
138150
if let Expr::Column(c) = &expr {
139151
match replace_map.get(c) {
140-
Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
152+
Some(replacement) => Transformed::yes(to_expr(replacement)),
141153
None => Transformed::no(expr),
142154
}
143155
} else {

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,10 @@ impl InferredPredicates {
631631
Ok(true)
632632
)
633633
{
634-
self.predicates.push(replace_col(predicate, replace_map)?);
634+
self.predicates
635+
.push(replace_col(predicate, replace_map, |col| {
636+
Expr::Column((*col).clone())
637+
})?);
635638
}
636639

637640
Ok(())
@@ -784,13 +787,14 @@ impl OptimizerRule for PushDownFilter {
784787

785788
// remove duplicated filters
786789
let child_predicates = split_conjunction_owned(child_filter.predicate);
787-
let new_predicates = parents_predicates
790+
let mut new_predicates = parents_predicates
788791
.into_iter()
789792
.chain(child_predicates)
790793
// use IndexSet to remove dupes while preserving predicate order
791794
.collect::<IndexSet<_>>()
792795
.into_iter()
793796
.collect::<Vec<_>>();
797+
new_predicates = infer_predicates_from_equalities(new_predicates)?;
794798

795799
let Some(new_predicate) = conjunction(new_predicates) else {
796800
return plan_err!("at least one expression exists");
@@ -1382,6 +1386,73 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
13821386
is_contain
13831387
}
13841388

1389+
/// Infers new predicates by substituting equalities.
1390+
/// For example, with predicates `t2.b = 3` and `t1.b > t2.b`,
1391+
/// we can infer `t1.b > 3`.
1392+
fn infer_predicates_from_equalities(predicates: Vec<Expr>) -> Result<Vec<Expr>> {
1393+
// Map from column names to their literal values (from equality predicates)
1394+
let mut equality_map: HashMap<Column, Expr> =
1395+
HashMap::with_capacity(predicates.len());
1396+
let mut final_predicates = Vec::with_capacity(predicates.len());
1397+
// First pass: collect column=literal equalities
1398+
for predicate in predicates.iter() {
1399+
if let Expr::BinaryExpr(BinaryExpr {
1400+
left,
1401+
op: Operator::Eq,
1402+
right,
1403+
}) = predicate
1404+
{
1405+
if let Expr::Column(col) = left.as_ref() {
1406+
// Only add to map if right side is a literal
1407+
if matches!(right.as_ref(), Expr::Literal(_)) {
1408+
equality_map.insert(col.clone(), *right.clone());
1409+
final_predicates.push(predicate.clone());
1410+
}
1411+
} else if let Expr::Column(col) = right.as_ref() {
1412+
// Only add to map if left side is a literal
1413+
if matches!(left.as_ref(), Expr::Literal(_)) {
1414+
equality_map.insert(col.clone(), *right.clone());
1415+
final_predicates.push(predicate.clone());
1416+
}
1417+
}
1418+
}
1419+
}
1420+
1421+
// If no equality mappings found, nothing to infer
1422+
if equality_map.is_empty() {
1423+
return Ok(predicates);
1424+
}
1425+
1426+
// Second pass: apply substitutions to create new predicates
1427+
for predicate in predicates {
1428+
// Skip equality predicates we already used for mapping
1429+
if final_predicates.contains(&predicate) {
1430+
continue;
1431+
}
1432+
1433+
// Try to replace columns with their literal values
1434+
let mut columns_in_expr = HashSet::new();
1435+
expr_to_columns(&predicate, &mut columns_in_expr)?;
1436+
1437+
// Create a combined replacement map for all columns in this predicate
1438+
let replace_map: HashMap<_, _> = columns_in_expr
1439+
.iter()
1440+
.filter_map(|col| equality_map.get(col).map(|lit| (col, lit)))
1441+
.collect();
1442+
1443+
if replace_map.is_empty() {
1444+
final_predicates.push(predicate);
1445+
continue;
1446+
}
1447+
// Apply all substitutions at once to get the fully substituted predicate
1448+
let new_pred = replace_col(predicate, &replace_map, |e| (*e).clone())?;
1449+
1450+
final_predicates.push(new_pred);
1451+
}
1452+
1453+
Ok(final_predicates)
1454+
}
1455+
13851456
#[cfg(test)]
13861457
mod tests {
13871458
use std::any::Any;

datafusion/optimizer/src/utils.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub(crate) fn replace_qualified_name(
5959
let replace_map: HashMap<&Column, &Column> =
6060
cols.iter().zip(alias_cols.iter()).collect();
6161

62-
replace_col(expr, &replace_map)
62+
replace_col(expr, &replace_map, |col| Expr::Column((*col).clone()))
6363
}
6464

6565
/// Log the plan in debug/tracing mode after some part of the optimizer runs
@@ -136,7 +136,9 @@ fn evaluate_expr_with_null_column<'a>(
136136
.map(|column| (column, &null_column))
137137
.collect::<HashMap<_, _>>();
138138

139-
let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
139+
let replaced_predicate = replace_col(predicate, &join_cols_to_replace, |col| {
140+
Expr::Column((*col).clone())
141+
})?;
140142
let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
141143
create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
142144
.evaluate(&input_batch)

datafusion/sqllogictest/test_files/push_down_filter.slt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,35 @@ logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8
259259

260260
statement ok
261261
drop table t;
262+
263+
statement ok
264+
create table t1(a int, b int) as values(1, 2), (2, 3), (3 ,4);
265+
266+
statement ok
267+
create table t2(a int, b int) as values (1, 2), (2, 4), (4, 5);
268+
269+
query TT
270+
explain select
271+
*
272+
from
273+
t1
274+
join t2 on t1.a = t2.a
275+
and t1.b between t2.b
276+
and t2.b + 2
277+
where
278+
t2.b = 3
279+
----
280+
logical_plan
281+
01)Inner Join: t1.a = t2.a
282+
02)--Projection: t1.a, t1.b
283+
03)----Filter: __common_expr_4 >= Int64(3) AND __common_expr_4 <= Int64(5)
284+
04)------Projection: CAST(t1.b AS Int64) AS __common_expr_4, t1.a, t1.b
285+
05)--------TableScan: t1 projection=[a, b]
286+
06)--Filter: t2.b = Int32(3)
287+
07)----TableScan: t2 projection=[a, b]
288+
289+
statement ok
290+
drop table t1;
291+
292+
statement ok
293+
drop table t2;

0 commit comments

Comments
 (0)