Skip to content

fix: preserve null_equals_null flag in eliminate_cross_join rule #16356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 99 additions & 21 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<LogicalPlan> = vec![];
let mut all_filters: Vec<Expr> = vec![];
let mut null_equals_null = false;

let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
// if input isn't a join that can potentially be rewritten
Expand All @@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin {
let Filter {
input, predicate, ..
} = filter;

// Extract null_equals_null setting from the input join
if let LogicalPlan::Join(join) = input.as_ref() {
null_equals_null = join.null_equals_null;
}

flatten_join_inputs(
Arc::unwrap_or_clone(input),
&mut possible_join_keys,
Expand All @@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin {

extract_possible_join_keys(&predicate, &mut possible_join_keys);
Some(predicate)
} else if matches!(
plan,
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
..
})
) {
if !can_flatten_join_inputs(&plan) {
return Ok(Transformed::no(plan));
}
flatten_join_inputs(
plan,
&mut possible_join_keys,
&mut all_inputs,
&mut all_filters,
)?;
None
} else {
// recursively try to rewrite children
return rewrite_children(self, plan, config);
match plan {
LogicalPlan::Join(Join {
join_type: JoinType::Inner,
null_equals_null: original_null_equals_null,
..
}) => {
if !can_flatten_join_inputs(&plan) {
return Ok(Transformed::no(plan));
}
flatten_join_inputs(
plan,
&mut possible_join_keys,
&mut all_inputs,
&mut all_filters,
)?;
null_equals_null = original_null_equals_null;
None
}
_ => {
// recursively try to rewrite children
return rewrite_children(self, plan, config);
}
}
};

// Join keys are handled locally:
Expand All @@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
&mut all_inputs,
&possible_join_keys,
&mut all_join_keys,
null_equals_null,
)?;
}

Expand Down Expand Up @@ -290,6 +302,7 @@ fn find_inner_join(
rights: &mut Vec<LogicalPlan>,
possible_join_keys: &JoinKeySet,
all_join_keys: &mut JoinKeySet,
null_equals_null: bool,
) -> Result<LogicalPlan> {
for (i, right_input) in rights.iter().enumerate() {
let mut join_keys = vec![];
Expand Down Expand Up @@ -328,7 +341,7 @@ fn find_inner_join(
on: join_keys,
filter: None,
schema: join_schema,
null_equals_null: false,
null_equals_null,
}));
}
}
Expand All @@ -350,7 +363,7 @@ fn find_inner_join(
filter: None,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
null_equals_null,
}))
}

Expand Down Expand Up @@ -1333,4 +1346,69 @@ mod tests {
"
)
}

#[test]
fn preserve_null_equals_null_setting() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;

// Create an inner join with null_equals_null: true
let join_schema = Arc::new(build_join_schema(
t1.schema(),
t2.schema(),
&JoinType::Inner,
)?);

let inner_join = LogicalPlan::Join(Join {
left: Arc::new(t1),
right: Arc::new(t2),
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
on: vec![],
filter: None,
schema: join_schema,
null_equals_null: true, // Set to true to test preservation
});

// Apply filter that can create join conditions
let plan = LogicalPlanBuilder::from(inner_join)
.filter(binary_expr(
col("t1.a").eq(col("t2.a")),
And,
col("t2.c").lt(lit(20u32)),
))?
.build()?;

let rule = EliminateCrossJoin::new();
let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;

// Verify that null_equals_null is preserved in the optimized plan
fn check_null_equals_null_preserved(plan: &LogicalPlan) -> bool {
match plan {
LogicalPlan::Join(join) => {
// All joins in the optimized plan should preserve null_equals_null: true
if !join.null_equals_null {
return false;
}
// Recursively check child plans
plan.inputs()
.iter()
.all(|input| check_null_equals_null_preserved(input))
}
_ => {
// Recursively check child plans for non-join nodes
plan.inputs()
.iter()
.all(|input| check_null_equals_null_preserved(input))
}
}
}

assert!(
check_null_equals_null_preserved(&optimized_plan),
"null_equals_null setting should be preserved after optimization"
);

Ok(())
}
}