@@ -30,22 +30,13 @@ use std::pin::Pin;
3030use std:: sync:: Arc ;
3131use std:: task:: { Context , Poll } ;
3232
33- use crate :: expressions:: PhysicalSortExpr ;
34- use crate :: joins:: utils:: {
35- build_join_schema, check_join_is_valid, estimate_join_statistics,
36- partitioned_join_output_partitioning, JoinFilter , JoinOn , JoinOnRef ,
37- } ;
38- use crate :: metrics:: { ExecutionPlanMetricsSet , MetricBuilder , MetricsSet } ;
39- use crate :: {
40- execution_mode_from_children, metrics, DisplayAs , DisplayFormatType , Distribution ,
41- ExecutionPlan , ExecutionPlanProperties , PhysicalExpr , PlanProperties ,
42- RecordBatchStream , SendableRecordBatchStream , Statistics ,
43- } ;
44-
4533use arrow:: array:: * ;
4634use arrow:: compute:: { self , concat_batches, take, SortOptions } ;
4735use arrow:: datatypes:: { DataType , SchemaRef , TimeUnit } ;
4836use arrow:: error:: ArrowError ;
37+ use futures:: { Stream , StreamExt } ;
38+ use hashbrown:: HashSet ;
39+
4940use datafusion_common:: {
5041 internal_err, not_impl_err, plan_err, DataFusionError , JoinSide , JoinType , Result ,
5142} ;
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
5445use datafusion_physical_expr:: equivalence:: join_equivalence_properties;
5546use datafusion_physical_expr:: { PhysicalExprRef , PhysicalSortRequirement } ;
5647
57- use futures:: { Stream , StreamExt } ;
48+ use crate :: expressions:: PhysicalSortExpr ;
49+ use crate :: joins:: utils:: {
50+ build_join_schema, check_join_is_valid, estimate_join_statistics,
51+ partitioned_join_output_partitioning, JoinFilter , JoinOn , JoinOnRef ,
52+ } ;
53+ use crate :: metrics:: { ExecutionPlanMetricsSet , MetricBuilder , MetricsSet } ;
54+ use crate :: {
55+ execution_mode_from_children, metrics, DisplayAs , DisplayFormatType , Distribution ,
56+ ExecutionPlan , ExecutionPlanProperties , PhysicalExpr , PlanProperties ,
57+ RecordBatchStream , SendableRecordBatchStream , Statistics ,
58+ } ;
5859
5960/// join execution plan executes partitions in parallel and combines them into a set of
6061/// partitions.
@@ -491,6 +492,10 @@ struct StreamedBatch {
491492 pub output_indices : Vec < StreamedJoinedChunk > ,
492493 /// Index of currently scanned batch from buffered data
493494 pub buffered_batch_idx : Option < usize > ,
495+ /// Indices that found a match for the given join filter
496+ /// Used for semi joins to keep track the streaming index which got a join filter match
497+ /// and already emitted to the output.
498+ pub join_filter_matched_idxs : HashSet < u64 > ,
494499}
495500
496501impl StreamedBatch {
@@ -502,6 +507,7 @@ impl StreamedBatch {
502507 join_arrays,
503508 output_indices : vec ! [ ] ,
504509 buffered_batch_idx : None ,
510+ join_filter_matched_idxs : HashSet :: new ( ) ,
505511 }
506512 }
507513
@@ -512,6 +518,7 @@ impl StreamedBatch {
512518 join_arrays : vec ! [ ] ,
513519 output_indices : vec ! [ ] ,
514520 buffered_batch_idx : None ,
521+ join_filter_matched_idxs : HashSet :: new ( ) ,
515522 }
516523 }
517524
@@ -990,7 +997,22 @@ impl SMJStream {
990997 }
991998 Ordering :: Equal => {
992999 if matches ! ( self . join_type, JoinType :: LeftSemi ) {
993- join_streamed = !self . streamed_joined ;
1000+ // if the join filter is specified then its needed to output the streamed index
1001+ // only if it has not been emitted before
1002+ // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
1003+ // filter match and prevents the same index to go into output more than once
1004+ if self . filter . is_some ( ) {
1005+ join_streamed = !self
1006+ . streamed_batch
1007+ . join_filter_matched_idxs
1008+ . contains ( & ( self . streamed_batch . idx as u64 ) )
1009+ && !self . streamed_joined ;
1010+ // if the join filter specified there can be references to buffered columns
1011+ // so buffered columns are needed to access them
1012+ join_buffered = join_streamed;
1013+ } else {
1014+ join_streamed = !self . streamed_joined ;
1015+ }
9941016 }
9951017 if matches ! (
9961018 self . join_type,
@@ -1134,17 +1156,15 @@ impl SMJStream {
11341156 . collect :: < Result < Vec < _ > , ArrowError > > ( ) ?;
11351157
11361158 let buffered_indices: UInt64Array = chunk. buffered_indices . finish ( ) ;
1137-
11381159 let mut buffered_columns =
11391160 if matches ! ( self . join_type, JoinType :: LeftSemi | JoinType :: LeftAnti ) {
11401161 vec ! [ ]
11411162 } else if let Some ( buffered_idx) = chunk. buffered_batch_idx {
1142- self . buffered_data . batches [ buffered_idx]
1143- . batch
1144- . columns ( )
1145- . iter ( )
1146- . map ( |column| take ( column, & buffered_indices, None ) )
1147- . collect :: < Result < Vec < _ > , ArrowError > > ( ) ?
1163+ get_buffered_columns (
1164+ & self . buffered_data ,
1165+ buffered_idx,
1166+ & buffered_indices,
1167+ ) ?
11481168 } else {
11491169 self . buffered_schema
11501170 . fields ( )
@@ -1161,6 +1181,15 @@ impl SMJStream {
11611181 let filter_columns = if chunk. buffered_batch_idx . is_some ( ) {
11621182 if matches ! ( self . join_type, JoinType :: Right ) {
11631183 get_filter_column ( & self . filter , & buffered_columns, & streamed_columns)
1184+ } else if matches ! ( self . join_type, JoinType :: LeftSemi ) {
1185+ // unwrap is safe here as we check is_some on top of if statement
1186+ let buffered_columns = get_buffered_columns (
1187+ & self . buffered_data ,
1188+ chunk. buffered_batch_idx . unwrap ( ) ,
1189+ & buffered_indices,
1190+ ) ?;
1191+
1192+ get_filter_column ( & self . filter , & streamed_columns, & buffered_columns)
11641193 } else {
11651194 get_filter_column ( & self . filter , & streamed_columns, & buffered_columns)
11661195 }
@@ -1195,7 +1224,17 @@ impl SMJStream {
11951224 . into_array ( filter_batch. num_rows ( ) ) ?;
11961225
11971226 // The selection mask of the filter
1198- let mask = datafusion_common:: cast:: as_boolean_array ( & filter_result) ?;
1227+ let mut mask =
1228+ datafusion_common:: cast:: as_boolean_array ( & filter_result) ?;
1229+
1230+ let maybe_filtered_join_mask: Option < ( BooleanArray , Vec < u64 > ) > =
1231+ get_filtered_join_mask ( self . join_type , streamed_indices, mask) ;
1232+ if let Some ( ref filtered_join_mask) = maybe_filtered_join_mask {
1233+ mask = & filtered_join_mask. 0 ;
1234+ self . streamed_batch
1235+ . join_filter_matched_idxs
1236+ . extend ( & filtered_join_mask. 1 ) ;
1237+ }
11991238
12001239 // Push the filtered batch to the output
12011240 let filtered_batch =
@@ -1365,6 +1404,69 @@ fn get_filter_column(
13651404 filter_columns
13661405}
13671406
1407+ /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
1408+ #[ inline( always) ]
1409+ fn get_buffered_columns (
1410+ buffered_data : & BufferedData ,
1411+ buffered_batch_idx : usize ,
1412+ buffered_indices : & UInt64Array ,
1413+ ) -> Result < Vec < ArrayRef > , ArrowError > {
1414+ buffered_data. batches [ buffered_batch_idx]
1415+ . batch
1416+ . columns ( )
1417+ . iter ( )
1418+ . map ( |column| take ( column, & buffered_indices, None ) )
1419+ . collect :: < Result < Vec < _ > , ArrowError > > ( )
1420+ }
1421+
1422+ // Calculate join filter bit mask considering join type specifics
1423+ // `streamed_indices` - array of streamed datasource JOINED row indices
1424+ // `mask` - array booleans representing computed join filter expression eval result:
1425+ // true = the row index matches the join filter
1426+ // false = the row index doesn't match the join filter
1427+ // `streamed_indices` have the same length as `mask`
1428+ fn get_filtered_join_mask (
1429+ join_type : JoinType ,
1430+ streamed_indices : UInt64Array ,
1431+ mask : & BooleanArray ,
1432+ ) -> Option < ( BooleanArray , Vec < u64 > ) > {
1433+ // for LeftSemi Join the filter mask should be calculated in its own way:
1434+ // if we find at least one matching row for specific streaming index
1435+ // we don't need to check any others for the same index
1436+ if matches ! ( join_type, JoinType :: LeftSemi ) {
1437+ // have we seen a filter match for a streaming index before
1438+ let mut seen_as_true: bool = false ;
1439+ let streamed_indices_length = streamed_indices. len ( ) ;
1440+ let mut corrected_mask: BooleanBuilder =
1441+ BooleanBuilder :: with_capacity ( streamed_indices_length) ;
1442+
1443+ let mut filter_matched_indices: Vec < u64 > = vec ! [ ] ;
1444+
1445+ #[ allow( clippy:: needless_range_loop) ]
1446+ for i in 0 ..streamed_indices_length {
1447+ // LeftSemi respects only first true values for specific streaming index,
1448+ // others true values for the same index must be false
1449+ if mask. value ( i) && !seen_as_true {
1450+ seen_as_true = true ;
1451+ corrected_mask. append_value ( true ) ;
1452+ filter_matched_indices. push ( streamed_indices. value ( i) ) ;
1453+ } else {
1454+ corrected_mask. append_value ( false ) ;
1455+ }
1456+
1457+ // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1458+ if i < streamed_indices_length - 1
1459+ && streamed_indices. value ( i) != streamed_indices. value ( i + 1 )
1460+ {
1461+ seen_as_true = false ;
1462+ }
1463+ }
1464+ Some ( ( corrected_mask. finish ( ) , filter_matched_indices) )
1465+ } else {
1466+ None
1467+ }
1468+ }
1469+
13681470/// Buffered data contains all buffered batches with one unique join key
13691471#[ derive( Debug , Default ) ]
13701472struct BufferedData {
@@ -1604,24 +1706,28 @@ fn is_join_arrays_equal(
16041706mod tests {
16051707 use std:: sync:: Arc ;
16061708
1607- use crate :: expressions:: Column ;
1608- use crate :: joins:: utils:: JoinOn ;
1609- use crate :: joins:: SortMergeJoinExec ;
1610- use crate :: memory:: MemoryExec ;
1611- use crate :: test:: build_table_i32;
1612- use crate :: { common, ExecutionPlan } ;
1613-
16141709 use arrow:: array:: { Date32Array , Date64Array , Int32Array } ;
16151710 use arrow:: compute:: SortOptions ;
16161711 use arrow:: datatypes:: { DataType , Field , Schema } ;
16171712 use arrow:: record_batch:: RecordBatch ;
1713+ use arrow_array:: { BooleanArray , UInt64Array } ;
1714+
1715+ use datafusion_common:: JoinType :: LeftSemi ;
16181716 use datafusion_common:: {
16191717 assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType , Result ,
16201718 } ;
16211719 use datafusion_execution:: config:: SessionConfig ;
16221720 use datafusion_execution:: runtime_env:: { RuntimeConfig , RuntimeEnv } ;
16231721 use datafusion_execution:: TaskContext ;
16241722
1723+ use crate :: expressions:: Column ;
1724+ use crate :: joins:: sort_merge_join:: get_filtered_join_mask;
1725+ use crate :: joins:: utils:: JoinOn ;
1726+ use crate :: joins:: SortMergeJoinExec ;
1727+ use crate :: memory:: MemoryExec ;
1728+ use crate :: test:: build_table_i32;
1729+ use crate :: { common, ExecutionPlan } ;
1730+
16251731 fn build_table (
16261732 a : ( & str , & Vec < i32 > ) ,
16271733 b : ( & str , & Vec < i32 > ) ,
@@ -2641,6 +2747,72 @@ mod tests {
26412747
26422748 Ok ( ( ) )
26432749 }
2750+
2751+ #[ tokio:: test]
2752+ async fn left_semi_join_filtered_mask ( ) -> Result < ( ) > {
2753+ assert_eq ! (
2754+ get_filtered_join_mask(
2755+ LeftSemi ,
2756+ UInt64Array :: from( vec![ 0 , 0 , 1 , 1 ] ) ,
2757+ & BooleanArray :: from( vec![ true , true , false , false ] )
2758+ ) ,
2759+ Some ( ( BooleanArray :: from( vec![ true , false , false , false ] ) , vec![ 0 ] ) )
2760+ ) ;
2761+
2762+ assert_eq ! (
2763+ get_filtered_join_mask(
2764+ LeftSemi ,
2765+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2766+ & BooleanArray :: from( vec![ true , true ] )
2767+ ) ,
2768+ Some ( ( BooleanArray :: from( vec![ true , true ] ) , vec![ 0 , 1 ] ) )
2769+ ) ;
2770+
2771+ assert_eq ! (
2772+ get_filtered_join_mask(
2773+ LeftSemi ,
2774+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2775+ & BooleanArray :: from( vec![ false , true ] )
2776+ ) ,
2777+ Some ( ( BooleanArray :: from( vec![ false , true ] ) , vec![ 1 ] ) )
2778+ ) ;
2779+
2780+ assert_eq ! (
2781+ get_filtered_join_mask(
2782+ LeftSemi ,
2783+ UInt64Array :: from( vec![ 0 , 1 ] ) ,
2784+ & BooleanArray :: from( vec![ true , false ] )
2785+ ) ,
2786+ Some ( ( BooleanArray :: from( vec![ true , false ] ) , vec![ 0 ] ) )
2787+ ) ;
2788+
2789+ assert_eq ! (
2790+ get_filtered_join_mask(
2791+ LeftSemi ,
2792+ UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2793+ & BooleanArray :: from( vec![ false , true , true , true , true , true ] )
2794+ ) ,
2795+ Some ( (
2796+ BooleanArray :: from( vec![ false , true , false , true , false , false ] ) ,
2797+ vec![ 0 , 1 ]
2798+ ) )
2799+ ) ;
2800+
2801+ assert_eq ! (
2802+ get_filtered_join_mask(
2803+ LeftSemi ,
2804+ UInt64Array :: from( vec![ 0 , 0 , 0 , 1 , 1 , 1 ] ) ,
2805+ & BooleanArray :: from( vec![ false , false , false , false , false , true ] )
2806+ ) ,
2807+ Some ( (
2808+ BooleanArray :: from( vec![ false , false , false , false , false , true ] ) ,
2809+ vec![ 1 ]
2810+ ) )
2811+ ) ;
2812+
2813+ Ok ( ( ) )
2814+ }
2815+
26442816 /// Returns the column names on the schema
26452817 fn columns ( schema : & Schema ) -> Vec < String > {
26462818 schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( )
0 commit comments