@@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
89
89
let mut possible_join_keys = JoinKeySet :: new ( ) ;
90
90
let mut all_inputs: Vec < LogicalPlan > = vec ! [ ] ;
91
91
let mut all_filters: Vec < Expr > = vec ! [ ] ;
92
+ let mut null_equals_null = false ;
92
93
93
94
let parent_predicate = if let LogicalPlan :: Filter ( filter) = plan {
94
95
// if input isn't a join that can potentially be rewritten
@@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin {
113
114
let Filter {
114
115
input, predicate, ..
115
116
} = filter;
117
+
118
+ // Extract null_equals_null setting from the input join
119
+ if let LogicalPlan :: Join ( join) = input. as_ref ( ) {
120
+ null_equals_null = join. null_equals_null ;
121
+ }
122
+
116
123
flatten_join_inputs (
117
124
Arc :: unwrap_or_clone ( input) ,
118
125
& mut possible_join_keys,
@@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin {
122
129
123
130
extract_possible_join_keys ( & predicate, & mut possible_join_keys) ;
124
131
Some ( predicate)
125
- } else if matches ! (
126
- plan,
127
- LogicalPlan :: Join ( Join {
128
- join_type: JoinType :: Inner ,
129
- ..
130
- } )
131
- ) {
132
- if !can_flatten_join_inputs ( & plan) {
133
- return Ok ( Transformed :: no ( plan) ) ;
134
- }
135
- flatten_join_inputs (
136
- plan,
137
- & mut possible_join_keys,
138
- & mut all_inputs,
139
- & mut all_filters,
140
- ) ?;
141
- None
142
132
} else {
143
- // recursively try to rewrite children
144
- return rewrite_children ( self , plan, config) ;
133
+ match plan {
134
+ LogicalPlan :: Join ( Join {
135
+ join_type : JoinType :: Inner ,
136
+ null_equals_null : original_null_equals_null,
137
+ ..
138
+ } ) => {
139
+ if !can_flatten_join_inputs ( & plan) {
140
+ return Ok ( Transformed :: no ( plan) ) ;
141
+ }
142
+ flatten_join_inputs (
143
+ plan,
144
+ & mut possible_join_keys,
145
+ & mut all_inputs,
146
+ & mut all_filters,
147
+ ) ?;
148
+ null_equals_null = original_null_equals_null;
149
+ None
150
+ }
151
+ _ => {
152
+ // recursively try to rewrite children
153
+ return rewrite_children ( self , plan, config) ;
154
+ }
155
+ }
145
156
} ;
146
157
147
158
// Join keys are handled locally:
@@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
153
164
& mut all_inputs,
154
165
& possible_join_keys,
155
166
& mut all_join_keys,
167
+ null_equals_null,
156
168
) ?;
157
169
}
158
170
@@ -290,6 +302,7 @@ fn find_inner_join(
290
302
rights : & mut Vec < LogicalPlan > ,
291
303
possible_join_keys : & JoinKeySet ,
292
304
all_join_keys : & mut JoinKeySet ,
305
+ null_equals_null : bool ,
293
306
) -> Result < LogicalPlan > {
294
307
for ( i, right_input) in rights. iter ( ) . enumerate ( ) {
295
308
let mut join_keys = vec ! [ ] ;
@@ -328,7 +341,7 @@ fn find_inner_join(
328
341
on : join_keys,
329
342
filter : None ,
330
343
schema : join_schema,
331
- null_equals_null : false ,
344
+ null_equals_null,
332
345
} ) ) ;
333
346
}
334
347
}
@@ -350,7 +363,7 @@ fn find_inner_join(
350
363
filter : None ,
351
364
join_type : JoinType :: Inner ,
352
365
join_constraint : JoinConstraint :: On ,
353
- null_equals_null : false ,
366
+ null_equals_null,
354
367
} ) )
355
368
}
356
369
@@ -1333,4 +1346,69 @@ mod tests {
1333
1346
"
1334
1347
)
1335
1348
}
1349
+
1350
+ #[ test]
1351
+ fn preserve_null_equals_null_setting ( ) -> Result < ( ) > {
1352
+ let t1 = test_table_scan_with_name ( "t1" ) ?;
1353
+ let t2 = test_table_scan_with_name ( "t2" ) ?;
1354
+
1355
+ // Create an inner join with null_equals_null: true
1356
+ let join_schema = Arc :: new ( build_join_schema (
1357
+ t1. schema ( ) ,
1358
+ t2. schema ( ) ,
1359
+ & JoinType :: Inner ,
1360
+ ) ?) ;
1361
+
1362
+ let inner_join = LogicalPlan :: Join ( Join {
1363
+ left : Arc :: new ( t1) ,
1364
+ right : Arc :: new ( t2) ,
1365
+ join_type : JoinType :: Inner ,
1366
+ join_constraint : JoinConstraint :: On ,
1367
+ on : vec ! [ ] ,
1368
+ filter : None ,
1369
+ schema : join_schema,
1370
+ null_equals_null : true , // Set to true to test preservation
1371
+ } ) ;
1372
+
1373
+ // Apply filter that can create join conditions
1374
+ let plan = LogicalPlanBuilder :: from ( inner_join)
1375
+ . filter ( binary_expr (
1376
+ col ( "t1.a" ) . eq ( col ( "t2.a" ) ) ,
1377
+ And ,
1378
+ col ( "t2.c" ) . lt ( lit ( 20u32 ) ) ,
1379
+ ) ) ?
1380
+ . build ( ) ?;
1381
+
1382
+ let rule = EliminateCrossJoin :: new ( ) ;
1383
+ let optimized_plan = rule. rewrite ( plan, & OptimizerContext :: new ( ) ) ?. data ;
1384
+
1385
+ // Verify that null_equals_null is preserved in the optimized plan
1386
+ fn check_null_equals_null_preserved ( plan : & LogicalPlan ) -> bool {
1387
+ match plan {
1388
+ LogicalPlan :: Join ( join) => {
1389
+ // All joins in the optimized plan should preserve null_equals_null: true
1390
+ if !join. null_equals_null {
1391
+ return false ;
1392
+ }
1393
+ // Recursively check child plans
1394
+ plan. inputs ( )
1395
+ . iter ( )
1396
+ . all ( |input| check_null_equals_null_preserved ( input) )
1397
+ }
1398
+ _ => {
1399
+ // Recursively check child plans for non-join nodes
1400
+ plan. inputs ( )
1401
+ . iter ( )
1402
+ . all ( |input| check_null_equals_null_preserved ( input) )
1403
+ }
1404
+ }
1405
+ }
1406
+
1407
+ assert ! (
1408
+ check_null_equals_null_preserved( & optimized_plan) ,
1409
+ "null_equals_null setting should be preserved after optimization"
1410
+ ) ;
1411
+
1412
+ Ok ( ( ) )
1413
+ }
1336
1414
}
0 commit comments