@@ -78,8 +78,26 @@ use arrow_buffer::BooleanBuffer;
7878use datafusion_expr:: Operator ;
7979use datafusion_physical_expr_common:: datum:: compare_op_for_nested;
8080use futures:: { ready, Stream , StreamExt , TryStreamExt } ;
81+ use log:: debug;
8182use parking_lot:: Mutex ;
8283
/// A `RandomState` constructed with all four seeds set to zero.
///
/// NOTE(review): presumably exposed so that hashes computed outside this module
/// use the same seeds as the join's build-side hash table; confirm against callers.
84+ pub const RANDOM_STATE : RandomState = RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ;
85+
86+ #[ derive( Default ) ]
87+ pub struct JoinContext {
88+ build_state : Mutex < Option < Arc < JoinLeftData > > > ,
89+ }
90+
91+ impl JoinContext {
92+ pub fn set_build_state ( & self , state : Arc < JoinLeftData > ) {
93+ self . build_state . lock ( ) . replace ( state) ;
94+ }
95+
96+ pub fn get_build_state ( & self ) -> Option < Arc < JoinLeftData > > {
97+ self . build_state . lock ( ) . clone ( )
98+ }
99+ }
100+
/// Handle around a shared `SharedJoinStateImpl` trait object.
83101pub struct SharedJoinState {
/// Reference-counted implementation that this handle wraps.
84102 state_impl : Arc < dyn SharedJoinStateImpl > ,
85103}
@@ -129,7 +147,7 @@ pub trait SharedJoinStateImpl: Send + Sync + 'static {
129147type SharedBitmapBuilder = Mutex < BooleanBufferBuilder > ;
130148
131149/// HashTable and input data for the left (build side) of a join
132- struct JoinLeftData {
150+ pub struct JoinLeftData {
133151 /// The hash table with indices into `batch`
134152 hash_map : JoinHashMap ,
135153 /// The input rows for the build side
@@ -167,6 +185,10 @@ impl JoinLeftData {
167185 }
168186 }
169187
188+ pub fn contains_hash ( & self , hash : u64 ) -> bool {
189+ self . hash_map . contains_hash ( hash)
190+ }
191+
170192 /// return a reference to the hash map
171193 fn hash_map ( & self ) -> & JoinHashMap {
172194 & self . hash_map
@@ -787,6 +809,7 @@ impl ExecutionPlan for HashJoinExec {
787809
788810 let distributed_state =
789811 context. session_config ( ) . get_extension :: < SharedJoinState > ( ) ;
812+ let join_context = context. session_config ( ) . get_extension :: < JoinContext > ( ) ;
790813
791814 let join_metrics = BuildProbeJoinMetrics :: new ( partition, & self . metrics ) ;
792815 let left_fut = match self . mode {
@@ -874,6 +897,7 @@ impl ExecutionPlan for HashJoinExec {
874897 batch_size,
875898 hashes_buffer : vec ! [ ] ,
876899 right_side_ordered : self . right . output_ordering ( ) . is_some ( ) ,
900+ join_context,
877901 } ) )
878902 }
879903
@@ -1199,6 +1223,7 @@ struct HashJoinStream {
11991223 hashes_buffer : Vec < u64 > ,
12001224 /// Specifies whether the right side has an ordering to potentially preserve
12011225 right_side_ordered : bool ,
1226+ join_context : Option < Arc < JoinContext > > ,
12021227}
12031228
12041229impl RecordBatchStream for HashJoinStream {
@@ -1411,6 +1436,11 @@ impl HashJoinStream {
14111436 . get_shared( cx) ) ?;
14121437 build_timer. done ( ) ;
14131438
1439+ if let Some ( ctx) = self . join_context . as_ref ( ) {
1440+ debug ! ( "setting join left data in join context" ) ;
1441+ ctx. set_build_state ( Arc :: clone ( & left_data) ) ;
1442+ }
1443+
14141444 self . state = HashJoinStreamState :: FetchProbeBatch ;
14151445 self . build_side = BuildSide :: Ready ( BuildSideReadyState { left_data } ) ;
14161446
0 commit comments