
Commit 2266474

alamb and yjshen authored

Fix bug while merging RecordBatch, add SortPreservingMerge fuzz tester (#1678)

* skip empty batch while inserting
* `SortPreservingMerge` fuzz testing

Co-authored-by: Yijie Shen <henry.yijieshen@gmail.com>
1 parent bf71577 commit 2266474

File tree: 2 files changed (+264, -28 lines)
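For context on the fix below: the condition it targets is an Arrow RecordBatch that carries a schema but zero rows, which is exactly the kind of batch the new fuzz tester injects between normal batches. A minimal standalone sketch (not part of this commit) of such a batch, using the same single Int32 column "x" that the tests build:

```rust
// Minimal sketch (not from this commit): an Arrow RecordBatch can be
// completely empty -- same schema, zero rows -- and such batches may show
// up anywhere in a partition's stream.
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::record_batch::RecordBatch;

fn main() {
    // An ordinary batch with one Int32 column named "x".
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_from_iter(vec![("x", values)]).unwrap();

    // An empty batch that shares the same schema.
    let empty = RecordBatch::new_empty(batch.schema());

    assert_eq!(batch.num_rows(), 3);
    assert_eq!(empty.num_rows(), 0);
    // Before this commit, SortPreservingMergeStream built a sort cursor even
    // for a zero-row batch; after it, zero-row batches are skipped and the
    // input stream is polled again.
}
```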

datafusion/src/physical_plan/sorts/sort_preserving_merge.rs

Lines changed: 41 additions & 28 deletions
@@ -410,40 +410,53 @@ impl SortPreservingMergeStream {
             // Cursor is not finished - don't need a new RecordBatch yet
             return Poll::Ready(Ok(()));
         }
-        let mut streams = self.streams.streams.lock().unwrap();
+        let mut empty_batch = false;
+        {
+            let mut streams = self.streams.streams.lock().unwrap();

-        let stream = &mut streams[idx];
-        if stream.is_terminated() {
-            return Poll::Ready(Ok(()));
-        }
-
-        // Fetch a new input record and create a cursor from it
-        match futures::ready!(stream.poll_next_unpin(cx)) {
-            None => return Poll::Ready(Ok(())),
-            Some(Err(e)) => {
-                return Poll::Ready(Err(e));
+            let stream = &mut streams[idx];
+            if stream.is_terminated() {
+                return Poll::Ready(Ok(()));
             }
-            Some(Ok(batch)) => {
-                let cursor = match SortKeyCursor::new(
-                    idx,
-                    self.next_batch_id, // assign this batch an ID
-                    &batch,
-                    &self.column_expressions,
-                    self.sort_options.clone(),
-                ) {
-                    Ok(cursor) => cursor,
-                    Err(e) => {
-                        return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+
+            // Fetch a new input record and create a cursor from it
+            match futures::ready!(stream.poll_next_unpin(cx)) {
+                None => return Poll::Ready(Ok(())),
+                Some(Err(e)) => {
+                    return Poll::Ready(Err(e));
+                }
+                Some(Ok(batch)) => {
+                    if batch.num_rows() > 0 {
+                        let cursor = match SortKeyCursor::new(
+                            idx,
+                            self.next_batch_id, // assign this batch an ID
+                            &batch,
+                            &self.column_expressions,
+                            self.sort_options.clone(),
+                        ) {
+                            Ok(cursor) => cursor,
+                            Err(e) => {
+                                return Poll::Ready(Err(ArrowError::ExternalError(
+                                    Box::new(e),
+                                )));
+                            }
+                        };
+                        self.next_batch_id += 1;
+                        self.min_heap.push(cursor);
+                        self.cursor_finished[idx] = false;
+                        self.batches[idx].push_back(batch)
+                    } else {
+                        empty_batch = true;
                     }
-                };
-                self.next_batch_id += 1;
-                self.min_heap.push(cursor);
-                self.cursor_finished[idx] = false;
-                self.batches[idx].push_back(batch)
+                }
             }
         }

-        Poll::Ready(Ok(()))
+        if empty_batch {
+            self.maybe_poll_stream(cx, idx)
+        } else {
+            Poll::Ready(Ok(()))
+        }
     }

     /// Drains the in_progress row indexes, and builds a new RecordBatch from them
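The shape of the change above: the stream lock is now taken inside a narrow scope, a zero-row batch sets `empty_batch` instead of producing a sort cursor, and the method then calls `maybe_poll_stream` again to fetch the next batch from the same input. A simplified, self-contained sketch of that skip-and-poll-again pattern (hypothetical types, not the DataFusion code itself; it uses a loop where the actual poll-based code recurses):

```rust
use std::collections::VecDeque;
use std::sync::Mutex;

/// Stand-in for the merge stream: a locked queue of batches, where a
/// "batch" is just a Vec<i32> here.
struct Merger {
    queue: Mutex<VecDeque<Vec<i32>>>,
}

impl Merger {
    /// Returns the next non-empty batch, skipping zero-row ones, mirroring
    /// the `empty_batch` / poll-again flow introduced by the fix.
    fn next_non_empty(&self) -> Option<Vec<i32>> {
        loop {
            // Narrow scope: the lock is released before deciding what to do,
            // just as the fix wraps the stream lock in its own block.
            let next = {
                let mut queue = self.queue.lock().unwrap();
                queue.pop_front()
            };
            match next {
                None => return None,                         // input exhausted
                Some(batch) if batch.is_empty() => continue, // skip, poll again
                Some(batch) => return Some(batch),
            }
        }
    }
}

fn main() {
    let merger = Merger {
        queue: Mutex::new(VecDeque::from(vec![vec![], vec![1, 2], vec![]])),
    };
    assert_eq!(merger.next_non_empty(), Some(vec![1, 2]));
    assert_eq!(merger.next_non_empty(), None);
}
```

Scoping the lock in the real patch means the guard is dropped before `maybe_poll_stream` is called again, presumably so the recursive call can re-acquire `self.streams` without contending with the lock it already held.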

datafusion/tests/merge_fuzz.rs

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Fuzz Test for various corner cases merging streams of RecordBatchs
use std::sync::Arc;

use arrow::{
    array::{ArrayRef, Int32Array},
    compute::SortOptions,
    record_batch::RecordBatch,
};
use datafusion::{
    execution::runtime_env::{RuntimeConfig, RuntimeEnv},
    physical_plan::{
        collect,
        expressions::{col, PhysicalSortExpr},
        memory::MemoryExec,
        sorts::sort_preserving_merge::SortPreservingMergeExec,
    },
};
use rand::{prelude::StdRng, Rng, SeedableRng};

#[tokio::test]
async fn test_merge_2() {
    run_merge_test(vec![
        // (0..100)
        // (0..100)
        make_staggered_batches(0, 100, 2),
        make_staggered_batches(0, 100, 3),
    ])
    .await
}

#[tokio::test]
async fn test_merge_2_no_overlap() {
    run_merge_test(vec![
        // (0..20)
        // (20..40)
        make_staggered_batches(0, 20, 2),
        make_staggered_batches(20, 40, 3),
    ])
    .await
}

#[tokio::test]
async fn test_merge_3() {
    run_merge_test(vec![
        // (0 .. 100)
        // (0 .. 100)
        // (0 .. 51)
        make_staggered_batches(0, 100, 2),
        make_staggered_batches(0, 100, 3),
        make_staggered_batches(0, 51, 4),
    ])
    .await
}

#[tokio::test]
async fn test_merge_3_gaps() {
    run_merge_test(vec![
        // (0 .. 50)(50 .. 100)
        // (0 ..33) (50 .. 100)
        // (0 .. 51)
        concat(
            make_staggered_batches(0, 50, 2),
            make_staggered_batches(50, 100, 7),
        ),
        concat(
            make_staggered_batches(0, 33, 21),
            make_staggered_batches(50, 123, 31),
        ),
        make_staggered_batches(0, 51, 11),
    ])
    .await
}

/// Merge a set of input streams using SortPreservingMergeExec and
/// `Vec::sort` and ensure the results are the same.
///
/// For each case, the `input` streams are turned into a set of of
/// streams which are then merged together by [SortPreservingMerge]
///
/// Each `Vec<RecordBatch>` in `input` must be sorted and have a
/// single Int32 field named 'x'.
async fn run_merge_test(input: Vec<Vec<RecordBatch>>) {
    // Produce output with the specified output batch sizes
    let batch_sizes = [1, 2, 7, 49, 50, 51, 100];

    for batch_size in batch_sizes {
        let first_batch = input
            .iter()
            .map(|p| p.iter())
            .flatten()
            .next()
            .expect("at least one batch");
        let schema = first_batch.schema();

        let sort = vec![PhysicalSortExpr {
            expr: col("x", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }];

        let exec = MemoryExec::try_new(&input, schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec)));

        let runtime_config = RuntimeConfig::new().with_batch_size(batch_size);

        let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap());
        let collected = collect(merge, runtime).await.unwrap();

        // verify the output batch size: all batches except the last
        // should contain `batch_size` rows
        for (i, batch) in collected.iter().enumerate() {
            if i < collected.len() - 1 {
                assert_eq!(
                    batch.num_rows(),
                    batch_size,
                    "Expected batch {} to have {} rows, got {}",
                    i,
                    batch_size,
                    batch.num_rows()
                );
            }
        }

        let expected = partitions_to_sorted_vec(&input);
        let actual = batches_to_vec(&collected);

        assert_eq!(expected, actual, "failure in @ batch_size {}", batch_size);
    }
}

/// Extracts the i32 values from the set of batches and returns them as a single Vec
fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
    batches
        .iter()
        .map(|batch| {
            assert_eq!(batch.num_columns(), 1);
            batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .iter()
        })
        .flatten()
        .collect()
}

// extract values from batches and sort them
fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) -> Vec<Option<i32>> {
    let mut values: Vec<_> = partitions
        .iter()
        .map(|batches| batches_to_vec(batches).into_iter())
        .flatten()
        .collect();

    values.sort_unstable();
    values
}

/// Return the values `low..high` in order, in randomly sized
/// record batches in a field named 'x' of type `Int32`
fn make_staggered_batches(low: i32, high: i32, seed: u64) -> Vec<RecordBatch> {
    let input: Int32Array = (low..high).map(Some).collect();

    // split into several record batches
    let mut remainder =
        RecordBatch::try_from_iter(vec![("x", Arc::new(input) as ArrayRef)]).unwrap();

    let mut batches = vec![];

    // use a random number generator to pick a random sized output
    let mut rng = StdRng::seed_from_u64(seed);
    while remainder.num_rows() > 0 {
        let batch_size = rng.gen_range(0..remainder.num_rows() + 1);

        batches.push(remainder.slice(0, batch_size));
        remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
    }

    add_empty_batches(batches, &mut rng)
}

/// Adds a random number of empty record batches into the stream
fn add_empty_batches(batches: Vec<RecordBatch>, rng: &mut StdRng) -> Vec<RecordBatch> {
    let schema = batches[0].schema();

    batches
        .into_iter()
        .map(|batch| {
            // insert 0, or 1 empty batches before and after the current batch
            let empty_batch = RecordBatch::new_empty(schema.clone());
            std::iter::repeat(empty_batch.clone())
                .take(rng.gen_range(0..2))
                .chain(std::iter::once(batch))
                .chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2)))
        })
        .flatten()
        .collect()
}

fn concat(mut v1: Vec<RecordBatch>, v2: Vec<RecordBatch>) -> Vec<RecordBatch> {
    v1.extend(v2);
    v1
}
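The four cases above all reduce to run_merge_test plus the two helpers make_staggered_batches and concat, so further corner cases are cheap to add. As an illustration only, a hypothetical extra case (not part of the commit) mixing overlapping ranges with a gap:

```rust
// Hypothetical additional case, not in the commit: three sorted inputs that
// combine overlap with a gap, reusing make_staggered_batches / concat above.
#[tokio::test]
async fn test_merge_3_overlap_and_gap() {
    run_merge_test(vec![
        // (0 .. 40)(60 .. 100)
        // (0 .. 100)
        // (30 .. 70)
        concat(
            make_staggered_batches(0, 40, 5),
            make_staggered_batches(60, 100, 13),
        ),
        make_staggered_batches(0, 100, 17),
        make_staggered_batches(30, 70, 19),
    ])
    .await
}
```

Placed under datafusion/tests/, the file should run as a normal Cargo integration test, e.g. with `cargo test --test merge_fuzz`.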
