Commit babe5cd

Address review comments
1 parent b652cee commit babe5cd

2 files changed: 204 additions & 22 deletions


datafusion/core/tests/fuzz_cases/sort_fuzz.rs

Lines changed: 193 additions & 17 deletions
@@ -20,7 +20,7 @@
 use std::sync::Arc;
 
 use arrow::{
-    array::{ArrayRef, Int32Array},
+    array::{as_string_array, ArrayRef, Int32Array, StringArray},
     compute::SortOptions,
     record_batch::RecordBatch,
 };
@@ -29,6 +29,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::cast::as_int32_array;
 use datafusion_execution::memory_pool::GreedyMemoryPool;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
@@ -42,12 +43,17 @@ const KB: usize = 1 << 10;
 #[cfg_attr(tarpaulin, ignore)]
 async fn test_sort_10k_mem() {
     for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(10 * KB)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
     }
 }
 
@@ -57,29 +63,119 @@ async fn test_sort_100k_mem() {
     for (batch_size, should_spill) in
         [(5, false), (10000, false), (20000, true), (1000000, true)]
     {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
+            .with_pool_size(100 * KB)
+            .with_should_spill(should_spill)
+            .run()
+            .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
+    }
+}
+
+#[tokio::test]
+#[cfg_attr(tarpaulin, ignore)]
+async fn test_sort_strings_100k_mem() {
+    for (batch_size, should_spill) in
+        [(5, false), (1000, false), (10000, true), (20000, true)]
+    {
+        let (input, collected) = SortTest::new()
+            .with_utf8_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(100 * KB)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let mut input = input
+            .iter()
+            .flat_map(|p| p.iter())
+            .flat_map(|b| {
+                let array = b.column(0);
+                as_string_array(array)
+                    .iter()
+                    .map(|s| s.unwrap().to_string())
+            })
+            .collect::<Vec<String>>();
+        input.sort_unstable();
+        let actual = collected
+            .iter()
+            .flat_map(|b| {
+                let array = b.column(0);
+                as_string_array(array)
+                    .iter()
+                    .map(|s| s.unwrap().to_string())
+            })
+            .collect::<Vec<String>>();
+        assert_eq!(input, actual);
+    }
+}
+
+#[tokio::test]
+#[cfg_attr(tarpaulin, ignore)]
+async fn test_sort_multi_columns_100k_mem() {
+    for (batch_size, should_spill) in
+        [(5, false), (1000, false), (10000, true), (20000, true)]
+    {
+        let (input, collected) = SortTest::new()
+            .with_int32_utf8_batches(batch_size)
+            .with_sort_columns(vec!["x", "y"])
+            .with_pool_size(100 * KB)
+            .with_should_spill(should_spill)
+            .run()
+            .await;
+
+        fn record_batch_to_vec(b: &RecordBatch) -> Vec<(i32, String)> {
+            let mut rows: Vec<_> = Vec::new();
+            let i32_array = as_int32_array(b.column(0)).unwrap();
+            let string_array = as_string_array(b.column(1));
+            for i in 0..b.num_rows() {
+                let str = string_array.value(i).to_string();
+                let i32 = i32_array.value(i);
+                rows.push((i32, str));
+            }
+            rows
+        }
+        let mut input = input
+            .iter()
+            .flat_map(|p| p.iter())
+            .flat_map(record_batch_to_vec)
+            .collect::<Vec<(i32, String)>>();
+        input.sort_unstable();
+        let actual = collected
+            .iter()
+            .flat_map(record_batch_to_vec)
+            .collect::<Vec<(i32, String)>>();
+        assert_eq!(input, actual);
     }
 }
 
 #[tokio::test]
 async fn test_sort_unlimited_mem() {
     for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, false)] {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(usize::MAX)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
     }
 }
+
 #[derive(Debug, Default)]
 struct SortTest {
     input: Vec<Vec<RecordBatch>>,
+    /// The names of the columns to sort by
+    sort_columns: Vec<String>,
     /// GreedyMemoryPool size, if specified
     pool_size: Option<usize>,
     /// If true, expect the sort to spill
@@ -91,12 +187,29 @@ impl SortTest {
         Default::default()
     }
 
+    fn with_sort_columns(mut self, sort_columns: Vec<&str>) -> Self {
+        self.sort_columns = sort_columns.iter().map(|s| s.to_string()).collect();
+        self
+    }
+
     /// Create batches of int32 values of rows
     fn with_int32_batches(mut self, rows: usize) -> Self {
         self.input = vec![make_staggered_i32_batches(rows)];
         self
     }
 
+    /// Create batches of utf8 values of rows
+    fn with_utf8_batches(mut self, rows: usize) -> Self {
+        self.input = vec![make_staggered_utf8_batches(rows)];
+        self
+    }
+
+    /// Create batches of int32 and utf8 values of rows
+    fn with_int32_utf8_batches(mut self, rows: usize) -> Self {
+        self.input = vec![make_staggered_i32_utf8_batches(rows)];
+        self
+    }
+
     /// specify that this test should use a memory pool of the specified size
     fn with_pool_size(mut self, pool_size: usize) -> Self {
         self.pool_size = Some(pool_size);
@@ -110,7 +223,7 @@
 
     /// Sort the input using SortExec and ensure the results are
     /// correct according to `Vec::sort` both with and without spilling
-    async fn run(&self) {
+    async fn run(&self) -> (Vec<Vec<RecordBatch>>, Vec<RecordBatch>) {
         let input = self.input.clone();
         let first_batch = input
             .iter()
@@ -119,16 +232,21 @@
             .expect("at least one batch");
         let schema = first_batch.schema();
 
-        let sort = LexOrdering::new(vec![PhysicalSortExpr {
-            expr: col("x", &schema).unwrap(),
-            options: SortOptions {
-                descending: false,
-                nulls_first: true,
-            },
-        }]);
+        let sort_ordering = LexOrdering::new(
+            self.sort_columns
+                .iter()
+                .map(|c| PhysicalSortExpr {
+                    expr: col(c, &schema).unwrap(),
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: true,
+                    },
+                })
+                .collect(),
+        );
 
         let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap();
-        let sort = Arc::new(SortExec::new(sort, exec));
+        let sort = Arc::new(SortExec::new(sort_ordering, exec));
 
         let session_config = SessionConfig::new();
         let session_ctx = if let Some(pool_size) = self.pool_size {
@@ -153,9 +271,6 @@
         let task_ctx = session_ctx.task_ctx();
         let collected = collect(sort.clone(), task_ctx).await.unwrap();
 
-        let expected = partitions_to_sorted_vec(&input);
-        let actual = batches_to_vec(&collected);
-
         if self.should_spill {
             assert_ne!(
                 sort.metrics().unwrap().spill_count().unwrap(),
@@ -175,7 +290,8 @@
             0,
             "The sort should have returned all memory used back to the memory pool"
         );
-        assert_eq!(expected, actual, "failure in @ pool_size {self:?}");
+
+        (input, collected)
     }
 }
 
@@ -203,3 +319,63 @@ fn make_staggered_i32_batches(len: usize) -> Vec<RecordBatch> {
     }
     batches
 }
+
+/// Return randomly sized record batches in a field named 'x' of type `Utf8`
+/// with randomized content
+fn make_staggered_utf8_batches(len: usize) -> Vec<RecordBatch> {
+    let mut rng = rand::thread_rng();
+    let max_batch = 1024;
+
+    let mut batches = vec![];
+    let mut remaining = len;
+    while remaining != 0 {
+        let to_read = rng.gen_range(0..=remaining.min(max_batch));
+        remaining -= to_read;
+
+        batches.push(
+            RecordBatch::try_from_iter(vec![(
+                "x",
+                Arc::new(StringArray::from_iter_values(
+                    (0..to_read).map(|_| format!("test_string_{}", rng.gen::<u32>())),
+                )) as ArrayRef,
+            )])
+            .unwrap(),
+        )
+    }
+    batches
+}
+
+/// Return randomly sized record batches in a field named 'x' of type `Int32`
+/// with randomized i32 content and a field named 'y' of type `Utf8`
+/// with randomized content
+fn make_staggered_i32_utf8_batches(len: usize) -> Vec<RecordBatch> {
+    let mut rng = rand::thread_rng();
+    let max_batch = 1024;
+
+    let mut batches = vec![];
+    let mut remaining = len;
+    while remaining != 0 {
+        let to_read = rng.gen_range(0..=remaining.min(max_batch));
+        remaining -= to_read;
+
+        batches.push(
+            RecordBatch::try_from_iter(vec![
+                (
+                    "x",
+                    Arc::new(Int32Array::from_iter_values(
+                        (0..to_read).map(|_| rng.gen()),
+                    )) as ArrayRef,
+                ),
+                (
+                    "y",
+                    Arc::new(StringArray::from_iter_values(
+                        (0..to_read).map(|_| format!("test_string_{}", rng.gen::<u32>())),
+                    )) as ArrayRef,
+                ),
+            ])
+            .unwrap(),
+        )
+    }
+
+    batches
+}
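
Note (not part of the commit): the int32 tests above assert against partitions_to_sorted_vec and batches_to_vec, which come from the repository's shared test utilities rather than this file. A minimal sketch of what those assertions are assumed to rely on — flattening the single Int32 column of every batch, with Vec::sort as the oracle — is shown below; the real helpers live elsewhere in the repo and may differ in detail.

use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_int32_array;

// Hypothetical stand-in for the shared helper: flatten the first (Int32)
// column of each batch into one Vec, keeping nulls as None.
fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
    batches
        .iter()
        .flat_map(|b| as_int32_array(b.column(0)).unwrap().iter())
        .collect()
}

// Hypothetical stand-in for the shared helper: flatten all input partitions
// and sort with Vec::sort to produce the expected output.
fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) -> Vec<Option<i32>> {
    let mut values: Vec<_> = partitions
        .iter()
        .flat_map(|p| batches_to_vec(p))
        .collect();
    values.sort();
    values
}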

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 11 additions & 5 deletions
@@ -225,6 +225,8 @@ struct ExternalSorter {
     // ========================================================================
     /// Potentially unsorted in memory buffer
     in_mem_batches: Vec<RecordBatch>,
+    /// if `Self::in_mem_batches` are sorted
+    in_mem_batches_sorted: bool,
 
     /// If data has previously been spilled, the locations of the
     /// spill files (in Arrow IPC format)
@@ -277,6 +279,7 @@ impl ExternalSorter {
         Self {
             schema,
             in_mem_batches: vec![],
+            in_mem_batches_sorted: false,
             spills: vec![],
             expr: expr.into(),
             metrics,
@@ -309,6 +312,7 @@ impl ExternalSorter {
         }
 
         self.in_mem_batches.push(input);
+        self.in_mem_batches_sorted = false;
        Ok(())
     }
 
@@ -423,7 +427,8 @@ impl ExternalSorter {
     async fn sort_or_spill_in_mem_batches(&mut self) -> Result<()> {
         // Release the memory reserved for merge back to the pool so
         // there is some left when `in_mem_sort_stream` requests an
-        // allocation.
+        // allocation. At the end of this function, memory will be
+        // reserved again for the next spill.
         self.merge_reservation.free();
 
         let before = self.reservation.size();
@@ -458,6 +463,7 @@
                     self.spills.push(spill_file);
                 } else {
                     self.in_mem_batches.push(batch);
+                    self.in_mem_batches_sorted = true;
                 }
             }
             Some(writer) => {
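
Aside (an assumption, not shown in this commit): the hunks above only show where in_mem_batches_sorted is written — cleared when a new batch is buffered, set when a sorted batch is retained in memory — not where it is read. The standalone sketch below only illustrates the general "dirty flag" pattern such a field enables, using plain integers in place of RecordBatch.

// Illustration only: mark the buffer unsorted on insert, sorted after an
// in-memory sort, and consult the flag to skip redundant re-sorting.
struct Buffer {
    items: Vec<i32>,  // stand-in for Vec<RecordBatch>
    sorted: bool,     // stand-in for in_mem_batches_sorted
}

impl Buffer {
    fn insert(&mut self, v: i32) {
        self.items.push(v);
        self.sorted = false; // mirrors pushing a new input batch in the diff
    }

    fn sort_if_needed(&mut self) {
        if !self.sorted {
            self.items.sort_unstable();
            self.sorted = true; // mirrors retaining a sorted batch in memory
        }
    }
}

fn main() {
    let mut buf = Buffer { items: vec![3, 1, 2], sorted: false };
    buf.sort_if_needed();
    buf.insert(0); // invalidates the flag again
    buf.sort_if_needed();
    assert_eq!(buf.items, vec![0, 1, 2, 3]);
}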
@@ -662,10 +668,10 @@ impl ExternalSorter {
     /// Estimate how much memory is needed to sort a `RecordBatch`.
     ///
     /// This is used to pre-reserve memory for the sort/merge. The sort/merge process involves
-    /// creating sorted copies of sorted columns in record batches, the sorted copies could be
-    /// in either row format or array format. Please refer to cursor.rs and stream.rs for more
-    /// details. No matter what format the sorted copies are, they will use more memory than
-    /// the original record batch.
+    /// creating sorted copies of sorted columns in record batches for speeding up comparison
+    /// in sorting and merging. The sorted copies are in either row format or array format.
+    /// Please refer to cursor.rs and stream.rs for more details. No matter what format the
+    /// sorted copies are, they will use more memory than the original record batch.
     fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize {
         // 2x may not be enough for some cases, but it's a good start.
         // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes`
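
Note (illustrative, not from this hunk): the updated doc comment explains that the sorted copies created during sort/merge always exceed the original batch's footprint, and the partially visible body uses a 2x factor as the starting estimate. Below is a minimal sketch of that kind of estimate, assuming RecordBatch::get_array_memory_size as the size source; the actual sizing logic is in sort.rs.

use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::record_batch::RecordBatch;

// Sketch of a 2x pre-reservation estimate for sorting one batch; the real
// get_reserved_byte_for_record_batch in sort.rs is the authoritative version.
fn estimated_sort_reservation(batch: &RecordBatch) -> usize {
    // Reserve twice the batch's in-memory footprint to cover the sorted
    // copies (row format or array format) produced while sorting/merging.
    batch.get_array_memory_size() * 2
}

fn main() {
    let col: ArrayRef = Arc::new(Int32Array::from(vec![5, 3, 9, 1]));
    let batch = RecordBatch::try_from_iter(vec![("x", col)]).unwrap();
    println!("reserve ~{} bytes", estimated_sort_reservation(&batch));
}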
