Skip to content

Commit c3345a5

Browse files
compheadalamb
authored andcommitted
Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (apache#10304)
* Fix: Sort Merge Join Left Semi crashes Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent f287035 commit c3345a5

File tree

2 files changed

+352
-29
lines changed

2 files changed

+352
-29
lines changed

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 201 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,13 @@ use std::pin::Pin;
3030
use std::sync::Arc;
3131
use std::task::{Context, Poll};
3232

33-
use crate::expressions::PhysicalSortExpr;
34-
use crate::joins::utils::{
35-
build_join_schema, check_join_is_valid, estimate_join_statistics,
36-
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
37-
};
38-
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
39-
use crate::{
40-
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
41-
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
42-
RecordBatchStream, SendableRecordBatchStream, Statistics,
43-
};
44-
4533
use arrow::array::*;
4634
use arrow::compute::{self, concat_batches, take, SortOptions};
4735
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
4836
use arrow::error::ArrowError;
37+
use futures::{Stream, StreamExt};
38+
use hashbrown::HashSet;
39+
4940
use datafusion_common::{
5041
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
5142
};
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
5445
use datafusion_physical_expr::equivalence::join_equivalence_properties;
5546
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
5647

57-
use futures::{Stream, StreamExt};
48+
use crate::expressions::PhysicalSortExpr;
49+
use crate::joins::utils::{
50+
build_join_schema, check_join_is_valid, estimate_join_statistics,
51+
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
52+
};
53+
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
54+
use crate::{
55+
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
56+
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
57+
RecordBatchStream, SendableRecordBatchStream, Statistics,
58+
};
5859

5960
/// join execution plan executes partitions in parallel and combines them into a set of
6061
/// partitions.
@@ -491,6 +492,10 @@ struct StreamedBatch {
491492
pub output_indices: Vec<StreamedJoinedChunk>,
492493
/// Index of currently scanned batch from buffered data
493494
pub buffered_batch_idx: Option<usize>,
495+
/// Indices that found a match for the given join filter
496+
/// Used for semi joins to keep track the streaming index which got a join filter match
497+
/// and already emitted to the output.
498+
pub join_filter_matched_idxs: HashSet<u64>,
494499
}
495500

496501
impl StreamedBatch {
@@ -502,6 +507,7 @@ impl StreamedBatch {
502507
join_arrays,
503508
output_indices: vec![],
504509
buffered_batch_idx: None,
510+
join_filter_matched_idxs: HashSet::new(),
505511
}
506512
}
507513

@@ -512,6 +518,7 @@ impl StreamedBatch {
512518
join_arrays: vec![],
513519
output_indices: vec![],
514520
buffered_batch_idx: None,
521+
join_filter_matched_idxs: HashSet::new(),
515522
}
516523
}
517524

@@ -990,7 +997,22 @@ impl SMJStream {
990997
}
991998
Ordering::Equal => {
992999
if matches!(self.join_type, JoinType::LeftSemi) {
993-
join_streamed = !self.streamed_joined;
1000+
// if the join filter is specified then its needed to output the streamed index
1001+
// only if it has not been emitted before
1002+
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
1003+
// filter match and prevents the same index to go into output more than once
1004+
if self.filter.is_some() {
1005+
join_streamed = !self
1006+
.streamed_batch
1007+
.join_filter_matched_idxs
1008+
.contains(&(self.streamed_batch.idx as u64))
1009+
&& !self.streamed_joined;
1010+
// if the join filter specified there can be references to buffered columns
1011+
// so buffered columns are needed to access them
1012+
join_buffered = join_streamed;
1013+
} else {
1014+
join_streamed = !self.streamed_joined;
1015+
}
9941016
}
9951017
if matches!(
9961018
self.join_type,
@@ -1134,17 +1156,15 @@ impl SMJStream {
11341156
.collect::<Result<Vec<_>, ArrowError>>()?;
11351157

11361158
let buffered_indices: UInt64Array = chunk.buffered_indices.finish();
1137-
11381159
let mut buffered_columns =
11391160
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
11401161
vec![]
11411162
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1142-
self.buffered_data.batches[buffered_idx]
1143-
.batch
1144-
.columns()
1145-
.iter()
1146-
.map(|column| take(column, &buffered_indices, None))
1147-
.collect::<Result<Vec<_>, ArrowError>>()?
1163+
get_buffered_columns(
1164+
&self.buffered_data,
1165+
buffered_idx,
1166+
&buffered_indices,
1167+
)?
11481168
} else {
11491169
self.buffered_schema
11501170
.fields()
@@ -1161,6 +1181,15 @@ impl SMJStream {
11611181
let filter_columns = if chunk.buffered_batch_idx.is_some() {
11621182
if matches!(self.join_type, JoinType::Right) {
11631183
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
1184+
} else if matches!(self.join_type, JoinType::LeftSemi) {
1185+
// unwrap is safe here as we check is_some on top of if statement
1186+
let buffered_columns = get_buffered_columns(
1187+
&self.buffered_data,
1188+
chunk.buffered_batch_idx.unwrap(),
1189+
&buffered_indices,
1190+
)?;
1191+
1192+
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
11641193
} else {
11651194
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
11661195
}
@@ -1195,7 +1224,17 @@ impl SMJStream {
11951224
.into_array(filter_batch.num_rows())?;
11961225

11971226
// The selection mask of the filter
1198-
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
1227+
let mut mask =
1228+
datafusion_common::cast::as_boolean_array(&filter_result)?;
1229+
1230+
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
1231+
get_filtered_join_mask(self.join_type, streamed_indices, mask);
1232+
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
1233+
mask = &filtered_join_mask.0;
1234+
self.streamed_batch
1235+
.join_filter_matched_idxs
1236+
.extend(&filtered_join_mask.1);
1237+
}
11991238

12001239
// Push the filtered batch to the output
12011240
let filtered_batch =
@@ -1365,6 +1404,69 @@ fn get_filter_column(
13651404
filter_columns
13661405
}
13671406

1407+
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
1408+
#[inline(always)]
1409+
fn get_buffered_columns(
1410+
buffered_data: &BufferedData,
1411+
buffered_batch_idx: usize,
1412+
buffered_indices: &UInt64Array,
1413+
) -> Result<Vec<ArrayRef>, ArrowError> {
1414+
buffered_data.batches[buffered_batch_idx]
1415+
.batch
1416+
.columns()
1417+
.iter()
1418+
.map(|column| take(column, &buffered_indices, None))
1419+
.collect::<Result<Vec<_>, ArrowError>>()
1420+
}
1421+
1422+
// Calculate join filter bit mask considering join type specifics
1423+
// `streamed_indices` - array of streamed datasource JOINED row indices
1424+
// `mask` - array booleans representing computed join filter expression eval result:
1425+
// true = the row index matches the join filter
1426+
// false = the row index doesn't match the join filter
1427+
// `streamed_indices` have the same length as `mask`
1428+
fn get_filtered_join_mask(
1429+
join_type: JoinType,
1430+
streamed_indices: UInt64Array,
1431+
mask: &BooleanArray,
1432+
) -> Option<(BooleanArray, Vec<u64>)> {
1433+
// for LeftSemi Join the filter mask should be calculated in its own way:
1434+
// if we find at least one matching row for specific streaming index
1435+
// we don't need to check any others for the same index
1436+
if matches!(join_type, JoinType::LeftSemi) {
1437+
// have we seen a filter match for a streaming index before
1438+
let mut seen_as_true: bool = false;
1439+
let streamed_indices_length = streamed_indices.len();
1440+
let mut corrected_mask: BooleanBuilder =
1441+
BooleanBuilder::with_capacity(streamed_indices_length);
1442+
1443+
let mut filter_matched_indices: Vec<u64> = vec![];
1444+
1445+
#[allow(clippy::needless_range_loop)]
1446+
for i in 0..streamed_indices_length {
1447+
// LeftSemi respects only first true values for specific streaming index,
1448+
// others true values for the same index must be false
1449+
if mask.value(i) && !seen_as_true {
1450+
seen_as_true = true;
1451+
corrected_mask.append_value(true);
1452+
filter_matched_indices.push(streamed_indices.value(i));
1453+
} else {
1454+
corrected_mask.append_value(false);
1455+
}
1456+
1457+
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1458+
if i < streamed_indices_length - 1
1459+
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
1460+
{
1461+
seen_as_true = false;
1462+
}
1463+
}
1464+
Some((corrected_mask.finish(), filter_matched_indices))
1465+
} else {
1466+
None
1467+
}
1468+
}
1469+
13681470
/// Buffered data contains all buffered batches with one unique join key
13691471
#[derive(Debug, Default)]
13701472
struct BufferedData {
@@ -1604,24 +1706,28 @@ fn is_join_arrays_equal(
16041706
mod tests {
16051707
use std::sync::Arc;
16061708

1607-
use crate::expressions::Column;
1608-
use crate::joins::utils::JoinOn;
1609-
use crate::joins::SortMergeJoinExec;
1610-
use crate::memory::MemoryExec;
1611-
use crate::test::build_table_i32;
1612-
use crate::{common, ExecutionPlan};
1613-
16141709
use arrow::array::{Date32Array, Date64Array, Int32Array};
16151710
use arrow::compute::SortOptions;
16161711
use arrow::datatypes::{DataType, Field, Schema};
16171712
use arrow::record_batch::RecordBatch;
1713+
use arrow_array::{BooleanArray, UInt64Array};
1714+
1715+
use datafusion_common::JoinType::LeftSemi;
16181716
use datafusion_common::{
16191717
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
16201718
};
16211719
use datafusion_execution::config::SessionConfig;
16221720
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
16231721
use datafusion_execution::TaskContext;
16241722

1723+
use crate::expressions::Column;
1724+
use crate::joins::sort_merge_join::get_filtered_join_mask;
1725+
use crate::joins::utils::JoinOn;
1726+
use crate::joins::SortMergeJoinExec;
1727+
use crate::memory::MemoryExec;
1728+
use crate::test::build_table_i32;
1729+
use crate::{common, ExecutionPlan};
1730+
16251731
fn build_table(
16261732
a: (&str, &Vec<i32>),
16271733
b: (&str, &Vec<i32>),
@@ -2641,6 +2747,72 @@ mod tests {
26412747

26422748
Ok(())
26432749
}
2750+
2751+
#[tokio::test]
2752+
async fn left_semi_join_filtered_mask() -> Result<()> {
2753+
assert_eq!(
2754+
get_filtered_join_mask(
2755+
LeftSemi,
2756+
UInt64Array::from(vec![0, 0, 1, 1]),
2757+
&BooleanArray::from(vec![true, true, false, false])
2758+
),
2759+
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
2760+
);
2761+
2762+
assert_eq!(
2763+
get_filtered_join_mask(
2764+
LeftSemi,
2765+
UInt64Array::from(vec![0, 1]),
2766+
&BooleanArray::from(vec![true, true])
2767+
),
2768+
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
2769+
);
2770+
2771+
assert_eq!(
2772+
get_filtered_join_mask(
2773+
LeftSemi,
2774+
UInt64Array::from(vec![0, 1]),
2775+
&BooleanArray::from(vec![false, true])
2776+
),
2777+
Some((BooleanArray::from(vec![false, true]), vec![1]))
2778+
);
2779+
2780+
assert_eq!(
2781+
get_filtered_join_mask(
2782+
LeftSemi,
2783+
UInt64Array::from(vec![0, 1]),
2784+
&BooleanArray::from(vec![true, false])
2785+
),
2786+
Some((BooleanArray::from(vec![true, false]), vec![0]))
2787+
);
2788+
2789+
assert_eq!(
2790+
get_filtered_join_mask(
2791+
LeftSemi,
2792+
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2793+
&BooleanArray::from(vec![false, true, true, true, true, true])
2794+
),
2795+
Some((
2796+
BooleanArray::from(vec![false, true, false, true, false, false]),
2797+
vec![0, 1]
2798+
))
2799+
);
2800+
2801+
assert_eq!(
2802+
get_filtered_join_mask(
2803+
LeftSemi,
2804+
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
2805+
&BooleanArray::from(vec![false, false, false, false, false, true])
2806+
),
2807+
Some((
2808+
BooleanArray::from(vec![false, false, false, false, false, true]),
2809+
vec![1]
2810+
))
2811+
);
2812+
2813+
Ok(())
2814+
}
2815+
26442816
/// Returns the column names on the schema
26452817
fn columns(schema: &Schema) -> Vec<String> {
26462818
schema.fields().iter().map(|f| f.name().clone()).collect()

0 commit comments

Comments
 (0)