Skip to content

Commit ea3b965

Browse files
authored
Memory reservation & metrics for cross join (#5339)
* Memory reservation & metrics for cross join
* memory_limit test & removed fixed error msg from test_overallocation
1 parent 20d08ab commit ea3b965

File tree

6 files changed

+246
-60
lines changed

6 files changed

+246
-60
lines changed

datafusion/core/src/physical_plan/common.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
use super::{RecordBatchStream, SendableRecordBatchStream};
2121
use crate::error::{DataFusionError, Result};
2222
use crate::execution::context::TaskContext;
23+
use crate::execution::memory_pool::MemoryReservation;
2324
use crate::physical_plan::metrics::MemTrackingMetrics;
2425
use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
2526
use arrow::datatypes::{Schema, SchemaRef};
@@ -28,6 +29,7 @@ use arrow::record_batch::RecordBatch;
2829
use datafusion_physical_expr::PhysicalSortExpr;
2930
use futures::{Future, Stream, StreamExt, TryStreamExt};
3031
use log::debug;
32+
use parking_lot::Mutex;
3133
use pin_project_lite::pin_project;
3234
use std::fs;
3335
use std::fs::{metadata, File};
@@ -37,6 +39,8 @@ use std::task::{Context, Poll};
3739
use tokio::sync::mpsc;
3840
use tokio::task::JoinHandle;
3941

42+
/// A `MemoryReservation` wrapped in `Arc<Mutex<..>>` so that a single
/// reservation can be held and updated from more than one place.
pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
43+
4044
/// Stream of record batches
4145
pub struct SizedRecordBatchStream {
4246
schema: SchemaRef,

datafusion/core/src/physical_plan/joins/cross_join.rs

Lines changed: 153 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ use arrow::datatypes::{Schema, SchemaRef};
2626
use arrow::record_batch::RecordBatch;
2727

2828
use crate::execution::context::TaskContext;
29+
use crate::execution::memory_pool::MemoryConsumer;
30+
use crate::physical_plan::common::SharedMemoryReservation;
31+
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
2932
use crate::physical_plan::{
3033
coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec,
3134
ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
@@ -35,12 +38,11 @@ use crate::physical_plan::{
3538
use crate::{error::Result, scalar::ScalarValue};
3639
use async_trait::async_trait;
3740
use datafusion_common::DataFusionError;
38-
use log::debug;
39-
use std::time::Instant;
41+
use parking_lot::Mutex;
4042

4143
use super::utils::{
42-
adjust_right_output_partitioning, cross_join_equivalence_properties, OnceAsync,
43-
OnceFut,
44+
adjust_right_output_partitioning, cross_join_equivalence_properties,
45+
BuildProbeJoinMetrics, OnceAsync, OnceFut,
4446
};
4547

4648
/// Data of the left side
@@ -58,6 +60,8 @@ pub struct CrossJoinExec {
5860
schema: SchemaRef,
5961
/// Build-side data
6062
left_fut: OnceAsync<JoinLeftData>,
63+
/// Execution plan metrics
64+
metrics: ExecutionPlanMetricsSet,
6165
}
6266

6367
impl CrossJoinExec {
@@ -79,6 +83,7 @@ impl CrossJoinExec {
7983
right,
8084
schema,
8185
left_fut: Default::default(),
86+
metrics: ExecutionPlanMetricsSet::default(),
8287
}
8388
}
8489

@@ -97,9 +102,9 @@ impl CrossJoinExec {
97102
async fn load_left_input(
98103
left: Arc<dyn ExecutionPlan>,
99104
context: Arc<TaskContext>,
105+
metrics: BuildProbeJoinMetrics,
106+
reservation: SharedMemoryReservation,
100107
) -> Result<JoinLeftData> {
101-
let start = Instant::now();
102-
103108
// merge all left parts into a single stream
104109
let merge = {
105110
if left.output_partitioning().partition_count() != 1 {
@@ -111,22 +116,28 @@ async fn load_left_input(
111116
let stream = merge.execute(0, context)?;
112117

113118
// Load all batches and count the rows
114-
let (batches, num_rows) = stream
115-
.try_fold((Vec::new(), 0usize), |mut acc, batch| async {
116-
acc.1 += batch.num_rows();
117-
acc.0.push(batch);
118-
Ok(acc)
119-
})
119+
let (batches, num_rows, _, _) = stream
120+
.try_fold(
121+
(Vec::new(), 0usize, metrics, reservation),
122+
|mut acc, batch| async {
123+
let batch_size = batch.get_array_memory_size();
124+
// Reserve memory for incoming batch
125+
acc.3.lock().try_grow(batch_size)?;
126+
// Update metrics
127+
acc.2.build_mem_used.add(batch_size);
128+
acc.2.build_input_batches.add(1);
129+
acc.2.build_input_rows.add(batch.num_rows());
130+
// Update rowcount
131+
acc.1 += batch.num_rows();
132+
// Push batch to output
133+
acc.0.push(batch);
134+
Ok(acc)
135+
},
136+
)
120137
.await?;
121138

122139
let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?;
123140

124-
debug!(
125-
"Built build-side of cross join containing {} rows in {} ms",
126-
num_rows,
127-
start.elapsed().as_millis()
128-
);
129-
130141
Ok(merged_batch)
131142
}
132143

@@ -143,6 +154,10 @@ impl ExecutionPlan for CrossJoinExec {
143154
vec![self.left.clone(), self.right.clone()]
144155
}
145156

157+
/// Returns the metrics recorded for this plan node: the result of
/// `clone_inner()` on the node's `ExecutionPlanMetricsSet`, always
/// wrapped in `Some`.
fn metrics(&self) -> Option<MetricsSet> {
158+
Some(self.metrics.clone_inner())
159+
}
160+
146161
/// Specifies whether this plan generates an infinite stream of records.
147162
/// If the plan does not support pipelining, but its input(s) are
148163
/// infinite, returns an error to indicate this.
@@ -205,21 +220,29 @@ impl ExecutionPlan for CrossJoinExec {
205220
) -> Result<SendableRecordBatchStream> {
206221
let stream = self.right.execute(partition, context.clone())?;
207222

208-
let left_fut = self
209-
.left_fut
210-
.once(|| load_left_input(self.left.clone(), context));
223+
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
224+
let reservation = Arc::new(Mutex::new(
225+
MemoryConsumer::new(format!("CrossJoinStream[{partition}]"))
226+
.register(context.memory_pool()),
227+
));
228+
229+
let left_fut = self.left_fut.once(|| {
230+
load_left_input(
231+
self.left.clone(),
232+
context,
233+
join_metrics.clone(),
234+
reservation.clone(),
235+
)
236+
});
211237

212238
Ok(Box::pin(CrossJoinStream {
213239
schema: self.schema.clone(),
214240
left_fut,
215241
right: stream,
216242
right_batch: Arc::new(parking_lot::Mutex::new(None)),
217243
left_index: 0,
218-
num_input_batches: 0,
219-
num_input_rows: 0,
220-
num_output_batches: 0,
221-
num_output_rows: 0,
222-
join_time: 0,
244+
join_metrics,
245+
reservation,
223246
}))
224247
}
225248

@@ -321,16 +344,10 @@ struct CrossJoinStream {
321344
left_index: usize,
322345
/// Current batch being processed from the right side
323346
right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
324-
/// number of input batches
325-
num_input_batches: usize,
326-
/// number of input rows
327-
num_input_rows: usize,
328-
/// number of batches produced
329-
num_output_batches: usize,
330-
/// number of rows produced
331-
num_output_rows: usize,
332-
/// total time for joining probe-side batches to the build-side batches
333-
join_time: usize,
347+
/// join execution metrics
348+
join_metrics: BuildProbeJoinMetrics,
349+
/// memory reservation
350+
reservation: SharedMemoryReservation,
334351
}
335352

336353
impl RecordBatchStream for CrossJoinStream {
@@ -385,28 +402,30 @@ impl CrossJoinStream {
385402
&mut self,
386403
cx: &mut std::task::Context<'_>,
387404
) -> std::task::Poll<Option<Result<RecordBatch>>> {
405+
let build_timer = self.join_metrics.build_time.timer();
388406
let left_data = match ready!(self.left_fut.get(cx)) {
389407
Ok(left_data) => left_data,
390408
Err(e) => return Poll::Ready(Some(Err(e))),
391409
};
410+
build_timer.done();
392411

393412
if left_data.num_rows() == 0 {
394413
return Poll::Ready(None);
395414
}
396415

397416
if self.left_index > 0 && self.left_index < left_data.num_rows() {
398-
let start = Instant::now();
417+
let join_timer = self.join_metrics.join_time.timer();
399418
let right_batch = {
400419
let right_batch = self.right_batch.lock();
401420
right_batch.clone().unwrap()
402421
};
403422
let result =
404423
build_batch(self.left_index, &right_batch, left_data, &self.schema);
405-
self.num_input_rows += right_batch.num_rows();
424+
self.join_metrics.input_rows.add(right_batch.num_rows());
406425
if let Ok(ref batch) = result {
407-
self.join_time += start.elapsed().as_millis() as usize;
408-
self.num_output_batches += 1;
409-
self.num_output_rows += batch.num_rows();
426+
join_timer.done();
427+
self.join_metrics.output_batches.add(1);
428+
self.join_metrics.output_rows.add(batch.num_rows());
410429
}
411430
self.left_index += 1;
412431
return Poll::Ready(Some(result));
@@ -416,15 +435,15 @@ impl CrossJoinStream {
416435
.poll_next_unpin(cx)
417436
.map(|maybe_batch| match maybe_batch {
418437
Some(Ok(batch)) => {
419-
let start = Instant::now();
438+
let join_timer = self.join_metrics.join_time.timer();
420439
let result =
421440
build_batch(self.left_index, &batch, left_data, &self.schema);
422-
self.num_input_batches += 1;
423-
self.num_input_rows += batch.num_rows();
441+
self.join_metrics.input_batches.add(1);
442+
self.join_metrics.input_rows.add(batch.num_rows());
424443
if let Ok(ref batch) = result {
425-
self.join_time += start.elapsed().as_millis() as usize;
426-
self.num_output_batches += 1;
427-
self.num_output_rows += batch.num_rows();
444+
join_timer.done();
445+
self.join_metrics.output_batches.add(1);
446+
self.join_metrics.output_rows.add(batch.num_rows());
428447
}
429448
self.left_index = 1;
430449

@@ -434,15 +453,7 @@ impl CrossJoinStream {
434453
Some(result)
435454
}
436455
other => {
437-
debug!(
438-
"Processed {} probe-side input batches containing {} rows and \
439-
produced {} output batches containing {} rows in {} ms",
440-
self.num_input_batches,
441-
self.num_input_rows,
442-
self.num_output_batches,
443-
self.num_output_rows,
444-
self.join_time
445-
);
456+
self.reservation.lock().free();
446457
other
447458
}
448459
})
@@ -452,6 +463,26 @@ impl CrossJoinStream {
452463
#[cfg(test)]
453464
mod tests {
454465
use super::*;
466+
use crate::assert_batches_sorted_eq;
467+
use crate::common::assert_contains;
468+
use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
469+
use crate::physical_plan::common;
470+
use crate::prelude::{SessionConfig, SessionContext};
471+
use crate::test::{build_table_scan_i32, columns};
472+
473+
/// Test helper: builds a `CrossJoinExec` over `left` and `right`, executes
/// partition 0 with `context`, and collects the resulting stream.
///
/// Returns the column header produced by `columns` for the join schema
/// together with every collected output batch. Errors from `execute` or
/// `common::collect` are propagated to the caller.
async fn join_collect(
474+
left: Arc<dyn ExecutionPlan>,
475+
right: Arc<dyn ExecutionPlan>,
476+
context: Arc<TaskContext>,
477+
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
478+
let join = CrossJoinExec::new(left, right);
479+
// Capture the header before `execute` so it reflects the join's own schema.
let columns_header = columns(&join.schema());
480+
481+
// Only partition 0 is executed by this helper.
let stream = join.execute(0, context)?;
482+
let batches = common::collect(stream).await?;
483+
484+
Ok((columns_header, batches))
485+
}
455486

456487
#[tokio::test]
457488
async fn test_stats_cartesian_product() {
@@ -589,4 +620,70 @@ mod tests {
589620

590621
assert_eq!(result, expected);
591622
}
623+
624+
/// Cross-joins a three-row left input with a two-row right input under a
/// default `SessionContext` and checks both the output column order
/// (left columns first, then right columns) and the six expected rows.
#[tokio::test]
625+
async fn test_join() -> Result<()> {
626+
let session_ctx = SessionContext::new();
627+
let task_ctx = session_ctx.task_ctx();
628+
629+
let left = build_table_scan_i32(
630+
("a1", &vec![1, 2, 3]),
631+
("b1", &vec![4, 5, 6]),
632+
("c1", &vec![7, 8, 9]),
633+
);
634+
let right = build_table_scan_i32(
635+
("a2", &vec![10, 11]),
636+
("b2", &vec![12, 13]),
637+
("c2", &vec![14, 15]),
638+
);
639+
640+
let (columns, batches) = join_collect(left, right, task_ctx).await?;
641+
642+
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
643+
// Every left row paired with every right row: 3 x 2 = 6 rows.
let expected = vec![
644+
"+----+----+----+----+----+----+",
645+
"| a1 | b1 | c1 | a2 | b2 | c2 |",
646+
"+----+----+----+----+----+----+",
647+
"| 1 | 4 | 7 | 10 | 12 | 14 |",
648+
"| 1 | 4 | 7 | 11 | 13 | 15 |",
649+
"| 2 | 5 | 8 | 10 | 12 | 14 |",
650+
"| 2 | 5 | 8 | 11 | 13 | 15 |",
651+
"| 3 | 6 | 9 | 10 | 12 | 14 |",
652+
"| 3 | 6 | 9 | 11 | 13 | 15 |",
653+
"+----+----+----+----+----+----+",
654+
];
655+
656+
// Uses the sorted comparison macro rather than asserting a row order.
assert_batches_sorted_eq!(expected, &batches);
657+
658+
Ok(())
659+
}
660+
661+
/// Runs a cross join under a runtime configured with
/// `with_memory_limit(100, 1.0)` and expects the join to fail with a
/// "Resources exhausted" error instead of producing batches.
// NOTE(review): the limit is presumably 100 bytes with the whole pool usable
// (fraction 1.0) — confirm against `RuntimeConfig::with_memory_limit`.
#[tokio::test]
662+
async fn test_overallocation() -> Result<()> {
663+
let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0);
664+
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
665+
let session_ctx =
666+
SessionContext::with_config_rt(SessionConfig::default(), runtime);
667+
let task_ctx = session_ctx.task_ctx();
668+
669+
let left = build_table_scan_i32(
670+
("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
671+
("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
672+
("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
673+
);
674+
let right = build_table_scan_i32(
675+
("a2", &vec![10, 11]),
676+
("b2", &vec![12, 13]),
677+
("c2", &vec![14, 15]),
678+
);
679+
680+
// The join must return an error; `unwrap_err` panics if it succeeds.
let err = join_collect(left, right, task_ctx).await.unwrap_err();
681+
682+
// Only the message prefix is matched, not the full error text.
assert_contains!(
683+
err.to_string(),
684+
"External error: Resources exhausted: Failed to allocate additional"
685+
);
686+
687+
Ok(())
688+
}
592689
}

0 commit comments

Comments
 (0)