Skip to content

Commit 12c4c86

Browse files
irenjjalamb
andauthored
feat: Use SchemaRef in JoinFilter (#14182)
* feat: Use `SchemaRef` in `JoinFilter` * Update datafusion/core/src/physical_optimizer/projection_pushdown.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/physical-plan/src/joins/join_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/physical-plan/src/joins/join_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/physical-plan/src/joins/join_filter.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * fix --------- Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 3a65be0 commit 12c4c86

File tree

11 files changed

+98
-60
lines changed

11 files changed

+98
-60
lines changed

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ fn update_join_filter(
12551255
side: col_idx.side,
12561256
})
12571257
.collect(),
1258-
join_filter.schema().clone(),
1258+
Arc::clone(join_filter.schema()),
12591259
)
12601260
})
12611261
}
@@ -2246,11 +2246,11 @@ mod tests {
22462246
side: JoinSide::Left,
22472247
},
22482248
],
2249-
Schema::new(vec![
2249+
Arc::new(Schema::new(vec![
22502250
Field::new("b_left_inter", DataType::Int32, true),
22512251
Field::new("a_right_inter", DataType::Int32, true),
22522252
Field::new("c_left_inter", DataType::Int32, true),
2253-
]),
2253+
])),
22542254
)),
22552255
&JoinType::Inner,
22562256
true,
@@ -2360,11 +2360,11 @@ mod tests {
23602360
side: JoinSide::Left,
23612361
},
23622362
],
2363-
Schema::new(vec![
2363+
Arc::new(Schema::new(vec![
23642364
Field::new("b_left_inter", DataType::Int32, true),
23652365
Field::new("a_right_inter", DataType::Int32, true),
23662366
Field::new("c_left_inter", DataType::Int32, true),
2367-
]),
2367+
])),
23682368
)),
23692369
&JoinType::Inner,
23702370
true,
@@ -2462,7 +2462,7 @@ mod tests {
24622462
Some(JoinFilter::new(
24632463
filter_expr,
24642464
filter_column_indices,
2465-
filter_schema,
2465+
Arc::new(filter_schema),
24662466
)),
24672467
&JoinType::Inner,
24682468
None,
@@ -2536,11 +2536,11 @@ mod tests {
25362536
side: JoinSide::Left,
25372537
},
25382538
],
2539-
Schema::new(vec![
2539+
Arc::new(Schema::new(vec![
25402540
Field::new("b_left_inter", DataType::Int32, true),
25412541
Field::new("a_right_inter", DataType::Int32, true),
25422542
Field::new("c_left_inter", DataType::Int32, true),
2543-
]),
2543+
])),
25442544
)),
25452545
&JoinType::Inner,
25462546
None,

datafusion/core/src/physical_planner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ impl DefaultPhysicalPlanner {
10901090
Some(join_utils::JoinFilter::new(
10911091
filter_expr,
10921092
column_indices,
1093-
filter_schema,
1093+
Arc::new(filter_schema),
10941094
))
10951095
}
10961096
_ => None,

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
8686
.with_nullable(true),
8787
]);
8888

89-
JoinFilter::new(less_filter, column_indices, intermediate_schema)
89+
JoinFilter::new(less_filter, column_indices, Arc::new(intermediate_schema))
9090
}
9191

9292
#[tokio::test]
@@ -327,7 +327,7 @@ impl JoinFuzzTestCase {
327327
/// on-condition schema
328328
fn intermediate_schema(&self) -> Schema {
329329
let filter_schema = if let Some(filter) = self.join_filter() {
330-
filter.schema().to_owned()
330+
filter.schema().as_ref().to_owned()
331331
} else {
332332
Schema::empty()
333333
};
@@ -483,7 +483,8 @@ impl JoinFuzzTestCase {
483483
let intermediate_schema = self.intermediate_schema();
484484
let expression = self.composite_filter_expression();
485485

486-
let filter = JoinFilter::new(expression, column_indices, intermediate_schema);
486+
let filter =
487+
JoinFilter::new(expression, column_indices, Arc::new(intermediate_schema));
487488

488489
Arc::new(
489490
NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type, None)

datafusion/physical-optimizer/src/join_selection.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ mod tests_statistical {
716716
Some(JoinFilter::new(
717717
expression,
718718
column_indices,
719-
intermediate_schema,
719+
Arc::new(intermediate_schema),
720720
))
721721
}
722722

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,7 +2576,7 @@ mod tests {
25762576
let filter = JoinFilter::new(
25772577
filter_expression,
25782578
column_indices.clone(),
2579-
intermediate_schema.clone(),
2579+
Arc::new(intermediate_schema.clone()),
25802580
);
25812581

25822582
let join = join_with_filter(
@@ -2611,8 +2611,11 @@ mod tests {
26112611
Operator::Gt,
26122612
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
26132613
)) as Arc<dyn PhysicalExpr>;
2614-
let filter =
2615-
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
2614+
let filter = JoinFilter::new(
2615+
filter_expression,
2616+
column_indices,
2617+
Arc::new(intermediate_schema),
2618+
);
26162619

26172620
let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?;
26182621

@@ -2700,7 +2703,7 @@ mod tests {
27002703
let filter = JoinFilter::new(
27012704
filter_expression,
27022705
column_indices.clone(),
2703-
intermediate_schema.clone(),
2706+
Arc::new(intermediate_schema.clone()),
27042707
);
27052708

27062709
let join = join_with_filter(
@@ -2738,8 +2741,11 @@ mod tests {
27382741
Arc::new(Literal::new(ScalarValue::Int32(Some(11)))),
27392742
)) as Arc<dyn PhysicalExpr>;
27402743

2741-
let filter =
2742-
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
2744+
let filter = JoinFilter::new(
2745+
filter_expression,
2746+
column_indices,
2747+
Arc::new(intermediate_schema.clone()),
2748+
);
27432749

27442750
let join =
27452751
join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?;
@@ -2822,7 +2828,7 @@ mod tests {
28222828
let filter = JoinFilter::new(
28232829
filter_expression,
28242830
column_indices.clone(),
2825-
intermediate_schema.clone(),
2831+
Arc::new(intermediate_schema.clone()),
28262832
);
28272833

28282834
let join = join_with_filter(
@@ -2861,8 +2867,11 @@ mod tests {
28612867
Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
28622868
)) as Arc<dyn PhysicalExpr>;
28632869

2864-
let filter =
2865-
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
2870+
let filter = JoinFilter::new(
2871+
filter_expression,
2872+
column_indices,
2873+
Arc::new(intermediate_schema),
2874+
);
28662875

28672876
let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?;
28682877

@@ -2951,7 +2960,7 @@ mod tests {
29512960
let filter = JoinFilter::new(
29522961
filter_expression,
29532962
column_indices,
2954-
intermediate_schema.clone(),
2963+
Arc::new(intermediate_schema.clone()),
29552964
);
29562965

29572966
let join = join_with_filter(
@@ -2995,8 +3004,11 @@ mod tests {
29953004
Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
29963005
)) as Arc<dyn PhysicalExpr>;
29973006

2998-
let filter =
2999-
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
3007+
let filter = JoinFilter::new(
3008+
filter_expression,
3009+
column_indices,
3010+
Arc::new(intermediate_schema),
3011+
);
30003012

30013013
let join =
30023014
join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?;
@@ -3359,7 +3371,11 @@ mod tests {
33593371
Arc::new(Column::new("c", 1)),
33603372
)) as Arc<dyn PhysicalExpr>;
33613373

3362-
JoinFilter::new(filter_expression, column_indices, intermediate_schema)
3374+
JoinFilter::new(
3375+
filter_expression,
3376+
column_indices,
3377+
Arc::new(intermediate_schema),
3378+
)
33633379
}
33643380

33653381
#[apply(batch_sizes)]

datafusion/physical-plan/src/joins/join_filter.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::joins::utils::ColumnIndex;
19-
use arrow_schema::Schema;
19+
use arrow_schema::SchemaRef;
2020
use datafusion_common::JoinSide;
2121
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
2222
use std::sync::Arc;
@@ -30,15 +30,15 @@ pub struct JoinFilter {
3030
/// Column indices required to construct intermediate batch for filtering
3131
pub(crate) column_indices: Vec<ColumnIndex>,
3232
/// Physical schema of intermediate batch
33-
pub(crate) schema: Schema,
33+
pub(crate) schema: SchemaRef,
3434
}
3535

3636
impl JoinFilter {
3737
/// Creates new JoinFilter
3838
pub fn new(
3939
expression: Arc<dyn PhysicalExpr>,
4040
column_indices: Vec<ColumnIndex>,
41-
schema: Schema,
41+
schema: SchemaRef,
4242
) -> JoinFilter {
4343
JoinFilter {
4444
expression,
@@ -76,7 +76,7 @@ impl JoinFilter {
7676
}
7777

7878
/// Intermediate batch schema
79-
pub fn schema(&self) -> &Schema {
79+
pub fn schema(&self) -> &SchemaRef {
8080
&self.schema
8181
}
8282

@@ -94,7 +94,7 @@ impl JoinFilter {
9494
JoinFilter::new(
9595
Arc::clone(self.expression()),
9696
column_indices,
97-
self.schema().clone(),
97+
Arc::clone(self.schema()),
9898
)
9999
}
100100
}

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,11 @@ pub(crate) mod tests {
11041104
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
11051105
as Arc<dyn PhysicalExpr>;
11061106

1107-
JoinFilter::new(filter_expression, column_indices, intermediate_schema)
1107+
JoinFilter::new(
1108+
filter_expression,
1109+
column_indices,
1110+
Arc::new(intermediate_schema),
1111+
)
11081112
}
11091113

11101114
pub(crate) async fn multi_partitioned_join_collect(
@@ -1514,7 +1518,11 @@ pub(crate) mod tests {
15141518
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
15151519
as Arc<dyn PhysicalExpr>;
15161520

1517-
JoinFilter::new(filter_expression, column_indices, intermediate_schema)
1521+
JoinFilter::new(
1522+
filter_expression,
1523+
column_indices,
1524+
Arc::new(intermediate_schema),
1525+
)
15181526
}
15191527

15201528
fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,10 +1759,8 @@ impl SortMergeJoinStream {
17591759
if !filter_columns.is_empty() {
17601760
if let Some(f) = &self.filter {
17611761
// Construct batch with only filter columns
1762-
let filter_batch = RecordBatch::try_new(
1763-
Arc::new(f.schema().clone()),
1764-
filter_columns,
1765-
)?;
1762+
let filter_batch =
1763+
RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;
17661764

17671765
let filter_result = f
17681766
.expression()
@@ -3182,10 +3180,10 @@ mod tests {
31823180
side: JoinSide::Right,
31833181
},
31843182
],
3185-
Schema::new(vec![
3183+
Arc::new(Schema::new(vec![
31863184
Field::new("c1", DataType::Int32, true),
31873185
Field::new("c2", DataType::Int32, true),
3188-
]),
3186+
])),
31893187
);
31903188
let (_, batches) =
31913189
join_collect_with_filter(left, right, on, filter, RightAnti).await?;

datafusion/physical-plan/src/joins/stream_join_utils.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,8 @@ pub mod tests {
856856
side: JoinSide::Right,
857857
},
858858
];
859-
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
859+
let filter =
860+
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
860861

861862
let left_sort_filter_expr = build_filter_input_order(
862863
JoinSide::Left,
@@ -983,7 +984,8 @@ pub mod tests {
983984
side: JoinSide::Right,
984985
},
985986
];
986-
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
987+
let filter =
988+
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
987989

988990
let left_schema = Arc::new(left_schema);
989991
let right_schema = Arc::new(right_schema);
@@ -1055,7 +1057,8 @@ pub mod tests {
10551057
side: JoinSide::Left,
10561058
},
10571059
];
1058-
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
1060+
let filter =
1061+
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
10591062

10601063
let schema = Schema::new(vec![
10611064
Field::new("a", DataType::Int32, false),

0 commit comments

Comments
 (0)