Commit 83c1026

fix: account for memory in RepartitionExec (#4820)
* refactor: explicit loop instead of (tail) recursion
* test: simplify
* fix: account for memory in `RepartitionExec`. Fixes #4816.
* fix: sorting memory limit test
1 parent 2db3d2e commit 83c1026

File tree

3 files changed (+116, -36 lines)

datafusion/core/src/physical_plan/aggregates/mod.rs

Lines changed: 6 additions & 13 deletions

@@ -746,7 +746,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1207,18 +1207,11 @@ mod tests {
             let err = common::collect(stream).await.unwrap_err();

             // error root cause traversal is a bit complicated, see #4172.
-            if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
-                if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                    assert!(
-                        matches!(err, DataFusionError::ResourcesExhausted(_)),
-                        "Wrong inner error type: {err}",
-                    );
-                } else {
-                    panic!("Wrong arrow error type: {err}")
-                }
-            } else {
-                panic!("Wrong outer error type: {err}")
-            }
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
         }

         Ok(())
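
The simplified assertion relies on DataFusionError::find_root(), which walks wrapped Arrow/external errors back to the innermost DataFusionError (the traversal that #4172 describes). A minimal sketch of the pattern, not taken from this commit, with a hand-built error chain standing in for the one the test previously unwrapped by hand; the message string is illustrative only:

use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError;

fn main() {
    // same nesting the removed code matched on:
    // DataFusionError::ArrowError -> ArrowError::ExternalError -> DataFusionError
    let inner = DataFusionError::ResourcesExhausted("budget exhausted".to_string());
    let wrapped =
        DataFusionError::ArrowError(ArrowError::ExternalError(Box::new(inner)));

    // find_root unwraps the external layers, so a single matches! replaces the
    // nested if-let / downcast_ref chain deleted above
    assert!(matches!(
        wrapped.find_root(),
        DataFusionError::ResourcesExhausted(_)
    ));
}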

datafusion/core/src/physical_plan/repartition.rs

Lines changed: 105 additions & 22 deletions

@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};

 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;

 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,

     /// Helper that ensures that that background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
             // for this would be to add spill-to-disk capabilities.
             let (sender, receiver) =
                 mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-            state.channels.insert(partition, (sender, receiver));
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                    .register(context.memory_pool()),
+            ));
+            state
+                .channels
+                .insert(partition, (sender, receiver, reservation));
         }

         // launch one async task per *input* partition
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
             let txs: HashMap<_, _> = state
                 .channels
                 .iter()
-                .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx.clone(), Arc::clone(reservation)))
+                })
                 .collect();

             let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
             // (and pass along any errors, including panic!s)
             let join_handle = tokio::spawn(Self::wait_for_task(
                 AbortOnDropSingle::new(input_task),
-                txs,
+                txs.into_iter()
+                    .map(|(partition, (tx, _reservation))| (partition, tx))
+                    .collect(),
             ));
             join_handles.push(join_handle);
         }
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {

         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }

@@ -439,7 +460,7 @@ impl RepartitionExec {
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
         r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
@@ -467,11 +488,16 @@ impl RepartitionExec {
            };

            partitioner.partition(batch, |partition, partitioned| {
+                let size = partitioned.get_array_memory_size();
+
                let timer = r_metrics.send_time.timer();
                // if there is still a receiver, send to it
-                if let Some(tx) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                    reservation.lock().try_grow(size)?;
+
                    if tx.send(Some(Ok(partitioned))).is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                        reservation.lock().shrink(size);
                        txs.remove(&partition);
                    }
                }
@@ -546,6 +572,9 @@ struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
 }

 impl Stream for RepartitionStream {
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
                 }
             }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
         }
     }
 }
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;
@@ -1078,4 +1124,41 @@ mod tests {
         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // setup up context
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // create physical plan
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // pull partitions
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }
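
Outside the diff, a minimal standalone sketch of the reservation pattern this change introduces: the sending half of the channel calls try_grow before buffering a batch, and the receiving half calls shrink once the batch has been pulled out. It assumes the memory_pool API used above plus a GreedyMemoryPool standing in for TaskContext::memory_pool(); the column, budget, and consumer name are illustrative, not taken from the DataFusion sources:

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int32Array};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> Result<()> {
    // hypothetical 1 MiB budget standing in for the pool from TaskContext
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));
    let mut reservation = MemoryConsumer::new("RepartitionExec[0]").register(&pool);

    let batch = RecordBatch::try_from_iter([(
        "a",
        Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
    )])?;
    let size = batch.get_array_memory_size();

    // sender side: account for the batch before it sits in the unbounded channel;
    // try_grow fails with ResourcesExhausted once the pool is spent
    reservation.try_grow(size)?;

    // receiver side: release the bytes once the batch has been consumed
    reservation.shrink(size);

    Ok(())
}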

datafusion/core/tests/memory_limit.rs

Lines changed: 5 additions & 1 deletion

@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)

     let runtime = RuntimeEnv::new(rt_config).unwrap();

-    let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+    let ctx = SessionContext::with_config_rt(
+        // do NOT re-partition (since RepartitionExec has also has a memory budget which we'll likely hit first)
+        SessionConfig::new().with_target_partitions(1),
+        Arc::new(runtime),
+    );
     ctx.register_table("t", Arc::new(table))
         .expect("registering table");
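
As a side note, with_memory_limit(limit, fraction) budgets the pool at roughly limit * fraction bytes, so the oom test in repartition.rs above effectively runs with a one-byte pool. A small sketch (hypothetical values, not part of the commit) combining the two test-side knobs, the tightened memory budget and the single target partition:

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

fn memory_limited_ctx() -> Result<SessionContext> {
    // hypothetical ~100 KiB budget; with_memory_limit(limit, fraction) caps the
    // pool at roughly limit * fraction bytes
    let runtime =
        RuntimeEnv::new(RuntimeConfig::new().with_memory_limit(100 * 1024, 1.0))?;

    // a single target partition keeps RepartitionExec (and its reservation) out
    // of the plan, so the operator under test hits the limit first
    let config = SessionConfig::new().with_target_partitions(1);

    Ok(SessionContext::with_config_rt(config, Arc::new(runtime)))
}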
