Skip to content

Commit da06eff

Browse files
committed
refactor(hash_join): Execute build side earlier
1 parent a0ba091 commit da06eff

File tree

1 file changed

+24
-26
lines changed

1 file changed

+24
-26
lines changed

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -800,34 +800,36 @@ impl ExecutionPlan for HashJoinExec {
800800

801801
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
802802
let left_fut = match self.mode {
803-
PartitionMode::CollectLeft => self.left_fut.once(|| {
804-
let reservation =
805-
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
803+
PartitionMode::CollectLeft => {
804+
let left = coalesce_partitions_if_needed(Arc::clone(&self.left));
805+
let left_stream = left.execute(0, Arc::clone(&context))?;
806806

807-
let left = coalesce_partitions_if_needed(self.left.clone());
808-
collect_left_input(
809-
0,
810-
self.random_state.clone(),
811-
left,
812-
on_left.clone(),
813-
Arc::clone(&context),
814-
join_metrics.clone(),
815-
reservation,
816-
need_produce_result_in_final(self.join_type),
817-
self.right().output_partitioning().partition_count(),
818-
)
819-
}),
807+
self.left_fut.once(|| {
808+
let reservation = MemoryConsumer::new("HashJoinInput")
809+
.register(context.memory_pool());
810+
811+
collect_left_input(
812+
self.random_state.clone(),
813+
left_stream,
814+
on_left.clone(),
815+
join_metrics.clone(),
816+
reservation,
817+
need_produce_result_in_final(self.join_type),
818+
self.right().output_partitioning().partition_count(),
819+
)
820+
})
821+
}
820822
PartitionMode::Partitioned => {
823+
let left_stream = self.left.execute(partition, Arc::clone(&context))?;
824+
821825
let reservation =
822826
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
823827
.register(context.memory_pool());
824828

825829
OnceFut::new(collect_left_input(
826-
partition,
827830
self.random_state.clone(),
828-
Arc::clone(&self.left),
831+
left_stream,
829832
on_left.clone(),
830-
Arc::clone(&context),
831833
join_metrics.clone(),
832834
reservation,
833835
need_produce_result_in_final(self.join_type),
@@ -943,25 +945,21 @@ fn coalesce_partitions_if_needed(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn Execut
943945
/// hash table (`LeftJoinData`)
944946
#[allow(clippy::too_many_arguments)]
945947
async fn collect_left_input(
946-
partition: usize,
947948
random_state: RandomState,
948-
left: Arc<dyn ExecutionPlan>,
949+
left_stream: SendableRecordBatchStream,
949950
on_left: Vec<PhysicalExprRef>,
950-
context: Arc<TaskContext>,
951951
metrics: BuildProbeJoinMetrics,
952952
reservation: MemoryReservation,
953953
with_visited_indices_bitmap: bool,
954954
probe_threads_count: usize,
955955
) -> Result<JoinLeftData> {
956-
let schema = left.schema();
957-
958-
let stream = left.execute(partition, Arc::clone(&context))?;
956+
let schema = left_stream.schema();
959957

960958
// This operation performs 2 steps at once:
961959
// 1. creates a [JoinHashMap] of all batches from the stream
962960
// 2. stores the batches in a vector.
963961
let initial = (Vec::new(), 0, metrics, reservation);
964-
let (batches, num_rows, metrics, mut reservation) = stream
962+
let (batches, num_rows, metrics, mut reservation) = left_stream
965963
.try_fold(initial, |mut acc, batch| async {
966964
let batch_size = get_record_batch_memory_size(&batch);
967965
// Reserve memory for incoming batch

0 commit comments

Comments
 (0)