
Commit 63ce486

Author: mertak-synnada

Chore: Do not return empty record batches from streams (#13794)

* do not emit empty record batches in plans
* change function signatures to Option<RecordBatch> if empty batches are possible
* format code
* shorten code
* change list_unnest_at_level to return an Option value
* add documentation
* take concat_batches into compute_aggregates function again
* create unit test for row_hash.rs
* add test for unnest
* add test for partial sort
* add test for bounded window agg
* add test for window agg
* apply simplifications and fix typo
1 parent 7e0fc14 commit 63ce486
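
The pattern the commit applies, in miniature: wherever a stage may have nothing to emit, the emitting function now returns Option<RecordBatch> instead of constructing a zero-row batch, and every caller branches on the Option. Below is a minimal, self-contained sketch of that idea; `Batch`, `emit`, and `ExecState` are toy stand-ins for the real DataFusion types, not its actual API.

struct Batch {
    num_rows: usize,
}

enum ExecState {
    ProducingOutput(Batch),
    Done,
}

// Before: `fn emit(...) -> Result<RecordBatch>` could hand back a zero-row
// batch. After: "nothing to emit" is encoded as `None`, so callers cannot
// accidentally forward an empty batch downstream.
fn emit(group_rows: usize) -> Option<Batch> {
    if group_rows == 0 {
        None
    } else {
        Some(Batch { num_rows: group_rows })
    }
}

fn main() {
    // Callers branch on the Option instead of checking `num_rows() == 0`.
    let state = emit(0).map_or(ExecState::Done, ExecState::ProducingOutput);
    assert!(matches!(state, ExecState::Done));

    if let Some(batch) = emit(128) {
        // Only non-empty batches ever reach the output path.
        assert!(batch.num_rows > 0);
    }
}

The `map_or(Done, ProducingOutput)` call mirrors the construction the commit itself uses in row_hash.rs below.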

File tree

9 files changed: +256 −111 lines

datafusion/core/src/dataframe/mod.rs

Lines changed: 24 additions & 0 deletions
@@ -2380,6 +2380,30 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn aggregate_assert_no_empty_batches() -> Result<()> {
+        // build plan using DataFrame API
+        let df = test_table().await?;
+        let group_expr = vec![col("c1")];
+        let aggr_expr = vec![
+            min(col("c12")),
+            max(col("c12")),
+            avg(col("c12")),
+            sum(col("c12")),
+            count(col("c12")),
+            count_distinct(col("c12")),
+            median(col("c12")),
+        ];
+
+        let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
+        // Empty batches should not be produced
+        for batch in df {
+            assert!(batch.num_rows() > 0);
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_aggregate_with_pk() -> Result<()> {
         // create the dataframe

datafusion/core/tests/dataframe/mod.rs

Lines changed: 43 additions & 0 deletions
@@ -1246,6 +1246,43 @@ async fn unnest_aggregate_columns() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn unnest_no_empty_batches() -> Result<()> {
+    let mut shape_id_builder = UInt32Builder::new();
+    let mut tag_id_builder = UInt32Builder::new();
+
+    for shape_id in 1..=10 {
+        for tag_id in 1..=10 {
+            shape_id_builder.append_value(shape_id as u32);
+            tag_id_builder.append_value((shape_id * 10 + tag_id) as u32);
+        }
+    }
+
+    let batch = RecordBatch::try_from_iter(vec![
+        ("shape_id", Arc::new(shape_id_builder.finish()) as ArrayRef),
+        ("tag_id", Arc::new(tag_id_builder.finish()) as ArrayRef),
+    ])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("shapes", batch)?;
+    let df = ctx.table("shapes").await?;
+
+    let results = df
+        .clone()
+        .aggregate(
+            vec![col("shape_id")],
+            vec![array_agg(col("tag_id")).alias("tag_id")],
+        )?
+        .collect()
+        .await?;
+
+    // Assert that there are no empty batches in the result
+    for rb in results {
+        assert!(rb.num_rows() > 0);
+    }
+    Ok(())
+}
+
 #[tokio::test]
 async fn unnest_array_agg() -> Result<()> {
     let mut shape_id_builder = UInt32Builder::new();
@@ -1268,6 +1305,12 @@ async fn unnest_array_agg() -> Result<()> {
     let df = ctx.table("shapes").await?;
 
     let results = df.clone().collect().await?;
+
+    // Assert that there are no empty batches in the result
+    for rb in results.clone() {
+        assert!(rb.num_rows() > 0);
+    }
+
     let expected = vec![
         "+----------+--------+",
         "| shape_id | tag_id |",

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 4 additions & 4 deletions
@@ -299,9 +299,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
             Linear,
         )?);
         let task_ctx = ctx.task_ctx();
-        let mut collected_results =
-            collect(running_window_exec, task_ctx).await?;
-        collected_results.retain(|batch| batch.num_rows() > 0);
+        let collected_results = collect(running_window_exec, task_ctx).await?;
         let input_batch_sizes = batches
             .iter()
             .map(|batch| batch.num_rows())
@@ -310,6 +308,8 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
             .iter()
             .map(|batch| batch.num_rows())
             .collect::<Vec<_>>();
+        // There should be no empty batches in the results
+        assert!(result_batch_sizes.iter().all(|e| *e > 0));
         if causal {
             // For causal window frames, we can generate results immediately
             // for each input batch. Hence, batch sizes should match.
@@ -688,8 +688,8 @@ async fn run_window_test(
     let collected_running = collect(running_window_exec, task_ctx)
         .await?
         .into_iter()
-        .filter(|b| b.num_rows() > 0)
         .collect::<Vec<_>>();
+    assert!(collected_running.iter().all(|rb| rb.num_rows() > 0));
 
     // BoundedWindowAggExec should produce more chunks than the usual WindowAggExec.
     // Otherwise it means that we cannot generate results in running mode.

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 32 additions & 14 deletions
@@ -654,9 +654,13 @@ impl Stream for GroupedHashAggregateStream {
                         }
 
                         if let Some(to_emit) = self.group_ordering.emit_to() {
-                            let batch = extract_ok!(self.emit(to_emit, false));
-                            self.exec_state = ExecutionState::ProducingOutput(batch);
                             timer.done();
+                            if let Some(batch) =
+                                extract_ok!(self.emit(to_emit, false))
+                            {
+                                self.exec_state =
+                                    ExecutionState::ProducingOutput(batch);
+                            };
                             // make sure the exec_state just set is not overwritten below
                             break 'reading_input;
                         }
@@ -693,9 +697,13 @@ impl Stream for GroupedHashAggregateStream {
                         }
 
                         if let Some(to_emit) = self.group_ordering.emit_to() {
-                            let batch = extract_ok!(self.emit(to_emit, false));
-                            self.exec_state = ExecutionState::ProducingOutput(batch);
                             timer.done();
+                            if let Some(batch) =
+                                extract_ok!(self.emit(to_emit, false))
+                            {
+                                self.exec_state =
+                                    ExecutionState::ProducingOutput(batch);
+                            };
                             // make sure the exec_state just set is not overwritten below
                             break 'reading_input;
                         }
@@ -768,6 +776,9 @@ impl Stream for GroupedHashAggregateStream {
                         let output = batch.slice(0, size);
                         (ExecutionState::ProducingOutput(remaining), output)
                     };
+                    // Empty record batches should not be emitted.
+                    // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+                    debug_assert!(output_batch.num_rows() > 0);
                     return Poll::Ready(Some(Ok(
                         output_batch.record_output(&self.baseline_metrics)
                     )));
@@ -902,14 +913,14 @@ impl GroupedHashAggregateStream {
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
         let schema = if spilling {
             Arc::clone(&self.spill_state.spill_schema)
         } else {
             self.schema()
         };
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(None);
         }
 
         let mut output = self.group_values.emit(emit_to)?;
@@ -937,7 +948,8 @@ impl GroupedHashAggregateStream {
         // over the target memory size after emission, we can emit again rather than returning Err.
         let _ = self.update_memory_reservation();
         let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+        debug_assert!(batch.num_rows() > 0);
+        Ok(Some(batch))
     }
 
     /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
@@ -963,7 +975,9 @@ impl GroupedHashAggregateStream {
 
     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let Some(emit) = self.emit(EmitTo::All, true)? else {
+            return Ok(());
+        };
         let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
         let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1008,8 +1022,9 @@ impl GroupedHashAggregateStream {
         {
             assert_eq!(self.mode, AggregateMode::Partial);
             let n = self.group_values.len() / self.batch_size * self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if let Some(batch) = self.emit(EmitTo::First(n), false)? {
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            };
         }
         Ok(())
     }
@@ -1019,7 +1034,9 @@ impl GroupedHashAggregateStream {
     /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let Some(batch) = self.emit(EmitTo::All, true)? else {
+            return Ok(());
+        };
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
@@ -1067,7 +1084,7 @@ impl GroupedHashAggregateStream {
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1096,8 +1113,9 @@ impl GroupedHashAggregateStream {
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if let Some(batch) = self.emit(EmitTo::All, false)? {
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                };
             }
         }
datafusion/physical-plan/src/sorts/partial_sort.rs

Lines changed: 68 additions & 29 deletions
@@ -363,31 +363,31 @@ impl PartialSortStream {
         if self.is_closed {
             return Poll::Ready(None);
         }
-        let result = match ready!(self.input.poll_next_unpin(cx)) {
-            Some(Ok(batch)) => {
-                if let Some(slice_point) =
-                    self.get_slice_point(self.common_prefix_length, &batch)?
-                {
-                    self.in_mem_batches.push(batch.slice(0, slice_point));
-                    let remaining_batch =
-                        batch.slice(slice_point, batch.num_rows() - slice_point);
-                    let sorted_batch = self.sort_in_mem_batches();
-                    self.in_mem_batches.push(remaining_batch);
-                    sorted_batch
-                } else {
-                    self.in_mem_batches.push(batch);
-                    Ok(RecordBatch::new_empty(self.schema()))
-                }
-            }
-            Some(Err(e)) => Err(e),
-            None => {
-                self.is_closed = true;
-                // once input is consumed, sort the rest of the inserted batches
-                self.sort_in_mem_batches()
-            }
-        };
-
-        Poll::Ready(Some(result))
+        loop {
+            return Poll::Ready(Some(match ready!(self.input.poll_next_unpin(cx)) {
+                Some(Ok(batch)) => {
+                    if let Some(slice_point) =
+                        self.get_slice_point(self.common_prefix_length, &batch)?
+                    {
+                        self.in_mem_batches.push(batch.slice(0, slice_point));
+                        let remaining_batch =
+                            batch.slice(slice_point, batch.num_rows() - slice_point);
+                        let sorted_batch = self.sort_in_mem_batches();
+                        self.in_mem_batches.push(remaining_batch);
+                        sorted_batch
+                    } else {
+                        self.in_mem_batches.push(batch);
+                        continue;
+                    }
+                }
+                Some(Err(e)) => Err(e),
+                None => {
+                    self.is_closed = true;
+                    // once input is consumed, sort the rest of the inserted batches
+                    self.sort_in_mem_batches()
+                }
+            }));
+        }
     }
 
     /// Returns a sorted RecordBatch from in_mem_batches and clears in_mem_batches
@@ -407,6 +407,9 @@ impl PartialSortStream {
                 self.is_closed = true;
             }
         }
+        // Empty record batches should not be emitted.
+        // They need to be treated as [`Option<RecordBatch>`]es and handled separately
+        debug_assert!(result.num_rows() > 0);
         Ok(result)
     }
 
@@ -731,7 +734,7 @@ mod tests {
         let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
         assert_eq!(
            result.iter().map(|r| r.num_rows()).collect_vec(),
-            [0, 125, 125, 0, 150]
+            [125, 125, 150]
         );
 
         assert_eq!(
@@ -760,10 +763,10 @@ mod tests {
             nulls_first: false,
         };
         for (fetch_size, expected_batch_num_rows) in [
-            (Some(50), vec![0, 50]),
-            (Some(120), vec![0, 120]),
-            (Some(150), vec![0, 125, 25]),
-            (Some(250), vec![0, 125, 125]),
+            (Some(50), vec![50]),
+            (Some(120), vec![120]),
+            (Some(150), vec![125, 25]),
+            (Some(250), vec![125, 125]),
         ] {
             let partial_sort_executor = PartialSortExec::new(
                 LexOrdering::new(vec![
@@ -810,6 +813,42 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_partial_sort_no_empty_batches() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let mem_exec = prepare_partitioned_input();
+        let schema = mem_exec.schema();
+        let option_asc = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+        let fetch_size = Some(250);
+        let partial_sort_executor = PartialSortExec::new(
+            LexOrdering::new(vec![
+                PhysicalSortExpr {
+                    expr: col("a", &schema)?,
+                    options: option_asc,
+                },
+                PhysicalSortExpr {
+                    expr: col("c", &schema)?,
+                    options: option_asc,
+                },
+            ]),
+            Arc::clone(&mem_exec),
+            1,
+        )
+        .with_fetch(fetch_size);
+
+        let partial_sort_exec =
+            Arc::new(partial_sort_executor.clone()) as Arc<dyn ExecutionPlan>;
+        let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?;
+        for rb in result {
+            assert!(rb.num_rows() > 0);
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_sort_metadata() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
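
The PartialSortStream change above takes a different route than the Option<RecordBatch> signature change: when a batch contains no complete sorted prefix yet, the stream buffers it and `continue`s its polling loop rather than returning an empty placeholder batch. A rough synchronous analogue of that control flow, with toy stand-ins (`Batch`, `split_sorted_prefix`) in place of the real stream and Arrow machinery:

struct Batch(Vec<i32>);

// Pretend "slice point": a batch is emittable only if it has more than one
// element (a stand-in for finding a sort boundary in the real operator).
fn split_sorted_prefix(b: &Batch) -> Option<usize> {
    if b.0.len() > 1 {
        Some(b.0.len() / 2)
    } else {
        None
    }
}

fn next_output(
    input: &mut impl Iterator<Item = Batch>,
    buffer: &mut Vec<i32>,
) -> Option<Batch> {
    loop {
        match input.next() {
            Some(batch) => {
                if let Some(split) = split_sorted_prefix(&batch) {
                    // Emittable prefix found: sort everything buffered so far
                    // plus the prefix, and keep the remainder for later.
                    buffer.extend_from_slice(&batch.0[..split]);
                    let mut out = std::mem::take(buffer);
                    out.sort();
                    buffer.extend_from_slice(&batch.0[split..]);
                    return Some(Batch(out));
                } else {
                    // No emittable prefix: buffer the batch and poll again,
                    // instead of returning an empty batch (the old behavior).
                    buffer.extend_from_slice(&batch.0);
                    continue;
                }
            }
            None => {
                // Input exhausted: flush the remainder, if any.
                if buffer.is_empty() {
                    return None;
                }
                let mut out = std::mem::take(buffer);
                out.sort();
                return Some(Batch(out));
            }
        }
    }
}

fn main() {
    let mut input =
        vec![Batch(vec![3]), Batch(vec![2, 1]), Batch(vec![5, 4])].into_iter();
    let mut buffer = Vec::new();
    // Every batch the loop yields is non-empty.
    while let Some(Batch(rows)) = next_output(&mut input, &mut buffer) {
        assert!(!rows.is_empty());
        println!("{rows:?}");
    }
}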
