Skip to content

Commit df49f9f

Browse files
authored
fix: preserve null_equals_null flag in eliminate_cross_join rule (#16356)
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
1 parent 9c98b01 commit df49f9f

File tree

1 file changed

+99
-21
lines changed

1 file changed

+99
-21
lines changed

datafusion/optimizer/src/eliminate_cross_join.rs

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
8989
let mut possible_join_keys = JoinKeySet::new();
9090
let mut all_inputs: Vec<LogicalPlan> = vec![];
9191
let mut all_filters: Vec<Expr> = vec![];
92+
let mut null_equals_null = false;
9293

9394
let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
9495
// if input isn't a join that can potentially be rewritten
@@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin {
113114
let Filter {
114115
input, predicate, ..
115116
} = filter;
117+
118+
// Extract null_equals_null setting from the input join
119+
if let LogicalPlan::Join(join) = input.as_ref() {
120+
null_equals_null = join.null_equals_null;
121+
}
122+
116123
flatten_join_inputs(
117124
Arc::unwrap_or_clone(input),
118125
&mut possible_join_keys,
@@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin {
122129

123130
extract_possible_join_keys(&predicate, &mut possible_join_keys);
124131
Some(predicate)
125-
} else if matches!(
126-
plan,
127-
LogicalPlan::Join(Join {
128-
join_type: JoinType::Inner,
129-
..
130-
})
131-
) {
132-
if !can_flatten_join_inputs(&plan) {
133-
return Ok(Transformed::no(plan));
134-
}
135-
flatten_join_inputs(
136-
plan,
137-
&mut possible_join_keys,
138-
&mut all_inputs,
139-
&mut all_filters,
140-
)?;
141-
None
142132
} else {
143-
// recursively try to rewrite children
144-
return rewrite_children(self, plan, config);
133+
match plan {
134+
LogicalPlan::Join(Join {
135+
join_type: JoinType::Inner,
136+
null_equals_null: original_null_equals_null,
137+
..
138+
}) => {
139+
if !can_flatten_join_inputs(&plan) {
140+
return Ok(Transformed::no(plan));
141+
}
142+
flatten_join_inputs(
143+
plan,
144+
&mut possible_join_keys,
145+
&mut all_inputs,
146+
&mut all_filters,
147+
)?;
148+
null_equals_null = original_null_equals_null;
149+
None
150+
}
151+
_ => {
152+
// recursively try to rewrite children
153+
return rewrite_children(self, plan, config);
154+
}
155+
}
145156
};
146157

147158
// Join keys are handled locally:
@@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
153164
&mut all_inputs,
154165
&possible_join_keys,
155166
&mut all_join_keys,
167+
null_equals_null,
156168
)?;
157169
}
158170

@@ -290,6 +302,7 @@ fn find_inner_join(
290302
rights: &mut Vec<LogicalPlan>,
291303
possible_join_keys: &JoinKeySet,
292304
all_join_keys: &mut JoinKeySet,
305+
null_equals_null: bool,
293306
) -> Result<LogicalPlan> {
294307
for (i, right_input) in rights.iter().enumerate() {
295308
let mut join_keys = vec![];
@@ -328,7 +341,7 @@ fn find_inner_join(
328341
on: join_keys,
329342
filter: None,
330343
schema: join_schema,
331-
null_equals_null: false,
344+
null_equals_null,
332345
}));
333346
}
334347
}
@@ -350,7 +363,7 @@ fn find_inner_join(
350363
filter: None,
351364
join_type: JoinType::Inner,
352365
join_constraint: JoinConstraint::On,
353-
null_equals_null: false,
366+
null_equals_null,
354367
}))
355368
}
356369

@@ -1333,4 +1346,69 @@ mod tests {
13331346
"
13341347
)
13351348
}
1349+
1350+
#[test]
1351+
fn preserve_null_equals_null_setting() -> Result<()> {
1352+
let t1 = test_table_scan_with_name("t1")?;
1353+
let t2 = test_table_scan_with_name("t2")?;
1354+
1355+
// Create an inner join with null_equals_null: true
1356+
let join_schema = Arc::new(build_join_schema(
1357+
t1.schema(),
1358+
t2.schema(),
1359+
&JoinType::Inner,
1360+
)?);
1361+
1362+
let inner_join = LogicalPlan::Join(Join {
1363+
left: Arc::new(t1),
1364+
right: Arc::new(t2),
1365+
join_type: JoinType::Inner,
1366+
join_constraint: JoinConstraint::On,
1367+
on: vec![],
1368+
filter: None,
1369+
schema: join_schema,
1370+
null_equals_null: true, // Set to true to test preservation
1371+
});
1372+
1373+
// Apply filter that can create join conditions
1374+
let plan = LogicalPlanBuilder::from(inner_join)
1375+
.filter(binary_expr(
1376+
col("t1.a").eq(col("t2.a")),
1377+
And,
1378+
col("t2.c").lt(lit(20u32)),
1379+
))?
1380+
.build()?;
1381+
1382+
let rule = EliminateCrossJoin::new();
1383+
let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
1384+
1385+
// Verify that null_equals_null is preserved in the optimized plan
1386+
fn check_null_equals_null_preserved(plan: &LogicalPlan) -> bool {
1387+
match plan {
1388+
LogicalPlan::Join(join) => {
1389+
// All joins in the optimized plan should preserve null_equals_null: true
1390+
if !join.null_equals_null {
1391+
return false;
1392+
}
1393+
// Recursively check child plans
1394+
plan.inputs()
1395+
.iter()
1396+
.all(|input| check_null_equals_null_preserved(input))
1397+
}
1398+
_ => {
1399+
// Recursively check child plans for non-join nodes
1400+
plan.inputs()
1401+
.iter()
1402+
.all(|input| check_null_equals_null_preserved(input))
1403+
}
1404+
}
1405+
}
1406+
1407+
assert!(
1408+
check_null_equals_null_preserved(&optimized_plan),
1409+
"null_equals_null setting should be preserved after optimization"
1410+
);
1411+
1412+
Ok(())
1413+
}
13361414
}

0 commit comments

Comments
 (0)