Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,7 @@ fn update_join_filter(
side: col_idx.side,
})
.collect(),
join_filter.schema().clone(),
Arc::clone(join_filter.schema()),
)
})
}
Expand Down Expand Up @@ -2246,11 +2246,11 @@ mod tests {
side: JoinSide::Left,
},
],
Schema::new(vec![
Arc::new(Schema::new(vec![
Field::new("b_left_inter", DataType::Int32, true),
Field::new("a_right_inter", DataType::Int32, true),
Field::new("c_left_inter", DataType::Int32, true),
]),
])),
)),
&JoinType::Inner,
true,
Expand Down Expand Up @@ -2360,11 +2360,11 @@ mod tests {
side: JoinSide::Left,
},
],
Schema::new(vec![
Arc::new(Schema::new(vec![
Field::new("b_left_inter", DataType::Int32, true),
Field::new("a_right_inter", DataType::Int32, true),
Field::new("c_left_inter", DataType::Int32, true),
]),
])),
)),
&JoinType::Inner,
true,
Expand Down Expand Up @@ -2462,7 +2462,7 @@ mod tests {
Some(JoinFilter::new(
filter_expr,
filter_column_indices,
filter_schema,
Arc::new(filter_schema),
)),
&JoinType::Inner,
None,
Expand Down Expand Up @@ -2536,11 +2536,11 @@ mod tests {
side: JoinSide::Left,
},
],
Schema::new(vec![
Arc::new(Schema::new(vec![
Field::new("b_left_inter", DataType::Int32, true),
Field::new("a_right_inter", DataType::Int32, true),
Field::new("c_left_inter", DataType::Int32, true),
]),
])),
)),
&JoinType::Inner,
None,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ impl DefaultPhysicalPlanner {
Some(join_utils::JoinFilter::new(
filter_expr,
column_indices,
filter_schema,
Arc::new(filter_schema),
))
}
_ => None,
Expand Down
7 changes: 4 additions & 3 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
.with_nullable(true),
]);

JoinFilter::new(less_filter, column_indices, intermediate_schema)
JoinFilter::new(less_filter, column_indices, Arc::new(intermediate_schema))
}

#[tokio::test]
Expand Down Expand Up @@ -327,7 +327,7 @@ impl JoinFuzzTestCase {
/// on-condition schema
fn intermediate_schema(&self) -> Schema {
let filter_schema = if let Some(filter) = self.join_filter() {
filter.schema().to_owned()
filter.schema().as_ref().to_owned()
} else {
Schema::empty()
};
Expand Down Expand Up @@ -483,7 +483,8 @@ impl JoinFuzzTestCase {
let intermediate_schema = self.intermediate_schema();
let expression = self.composite_filter_expression();

let filter = JoinFilter::new(expression, column_indices, intermediate_schema);
let filter =
JoinFilter::new(expression, column_indices, Arc::new(intermediate_schema));

Arc::new(
NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type, None)
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-optimizer/src/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ mod tests_statistical {
Some(JoinFilter::new(
expression,
column_indices,
intermediate_schema,
Arc::new(intermediate_schema),
))
}

Expand Down
42 changes: 29 additions & 13 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ mod tests {
let filter = JoinFilter::new(
filter_expression,
column_indices.clone(),
intermediate_schema.clone(),
Arc::new(intermediate_schema.clone()),
);

let join = join_with_filter(
Expand Down Expand Up @@ -2611,8 +2611,11 @@ mod tests {
Operator::Gt,
Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
)) as Arc<dyn PhysicalExpr>;
let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
);

let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?;

Expand Down Expand Up @@ -2700,7 +2703,7 @@ mod tests {
let filter = JoinFilter::new(
filter_expression,
column_indices.clone(),
intermediate_schema.clone(),
Arc::new(intermediate_schema.clone()),
);

let join = join_with_filter(
Expand Down Expand Up @@ -2738,8 +2741,11 @@ mod tests {
Arc::new(Literal::new(ScalarValue::Int32(Some(11)))),
)) as Arc<dyn PhysicalExpr>;

let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema.clone()),
);

let join =
join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?;
Expand Down Expand Up @@ -2822,7 +2828,7 @@ mod tests {
let filter = JoinFilter::new(
filter_expression,
column_indices.clone(),
intermediate_schema.clone(),
Arc::new(intermediate_schema.clone()),
);

let join = join_with_filter(
Expand Down Expand Up @@ -2861,8 +2867,11 @@ mod tests {
Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
)) as Arc<dyn PhysicalExpr>;

let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
);

let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?;

Expand Down Expand Up @@ -2951,7 +2960,7 @@ mod tests {
let filter = JoinFilter::new(
filter_expression,
column_indices,
intermediate_schema.clone(),
Arc::new(intermediate_schema.clone()),
);

let join = join_with_filter(
Expand Down Expand Up @@ -2995,8 +3004,11 @@ mod tests {
Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
)) as Arc<dyn PhysicalExpr>;

let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);
let filter = JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
);

let join =
join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?;
Expand Down Expand Up @@ -3359,7 +3371,11 @@ mod tests {
Arc::new(Column::new("c", 1)),
)) as Arc<dyn PhysicalExpr>;

JoinFilter::new(filter_expression, column_indices, intermediate_schema)
JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
)
}

#[apply(batch_sizes)]
Expand Down
10 changes: 5 additions & 5 deletions datafusion/physical-plan/src/joins/join_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::joins::utils::ColumnIndex;
use arrow_schema::Schema;
use arrow_schema::SchemaRef;
use datafusion_common::JoinSide;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;
Expand All @@ -30,15 +30,15 @@ pub struct JoinFilter {
/// Column indices required to construct intermediate batch for filtering
pub(crate) column_indices: Vec<ColumnIndex>,
/// Physical schema of intermediate batch
pub(crate) schema: Schema,
pub(crate) schema: SchemaRef,
}

impl JoinFilter {
/// Creates new JoinFilter
pub fn new(
expression: Arc<dyn PhysicalExpr>,
column_indices: Vec<ColumnIndex>,
schema: Schema,
schema: SchemaRef,
) -> JoinFilter {
JoinFilter {
expression,
Expand Down Expand Up @@ -76,7 +76,7 @@ impl JoinFilter {
}

/// Intermediate batch schema
pub fn schema(&self) -> &Schema {
pub fn schema(&self) -> &SchemaRef {
&self.schema
}

Expand All @@ -94,7 +94,7 @@ impl JoinFilter {
JoinFilter::new(
Arc::clone(self.expression()),
column_indices,
self.schema().clone(),
Arc::clone(self.schema()),
)
}
}
12 changes: 10 additions & 2 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,11 @@ pub(crate) mod tests {
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
as Arc<dyn PhysicalExpr>;

JoinFilter::new(filter_expression, column_indices, intermediate_schema)
JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
)
}

pub(crate) async fn multi_partitioned_join_collect(
Expand Down Expand Up @@ -1514,7 +1518,11 @@ pub(crate) mod tests {
Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
as Arc<dyn PhysicalExpr>;

JoinFilter::new(filter_expression, column_indices, intermediate_schema)
JoinFilter::new(
filter_expression,
column_indices,
Arc::new(intermediate_schema),
)
}

fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
Expand Down
10 changes: 4 additions & 6 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1759,10 +1759,8 @@ impl SortMergeJoinStream {
if !filter_columns.is_empty() {
if let Some(f) = &self.filter {
// Construct batch with only filter columns
let filter_batch = RecordBatch::try_new(
Arc::new(f.schema().clone()),
filter_columns,
)?;
let filter_batch =
RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;

let filter_result = f
.expression()
Expand Down Expand Up @@ -3182,10 +3180,10 @@ mod tests {
side: JoinSide::Right,
},
],
Schema::new(vec![
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
]),
])),
);
let (_, batches) =
join_collect_with_filter(left, right, on, filter, RightAnti).await?;
Expand Down
9 changes: 6 additions & 3 deletions datafusion/physical-plan/src/joins/stream_join_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,8 @@ pub mod tests {
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
let filter =
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

let left_sort_filter_expr = build_filter_input_order(
JoinSide::Left,
Expand Down Expand Up @@ -983,7 +984,8 @@ pub mod tests {
side: JoinSide::Right,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
let filter =
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

let left_schema = Arc::new(left_schema);
let right_schema = Arc::new(right_schema);
Expand Down Expand Up @@ -1055,7 +1057,8 @@ pub mod tests {
side: JoinSide::Left,
},
];
let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
let filter =
JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Expand Down
Loading
Loading