Skip to content

Commit bd3ee23

Browse files
authored
perf: improve performance of SortPreservingMergeExec operator (#722)
* perf: re-use Array comparators This commit stores built Arrow comparators for two arrays on each of the sort key cursors, resulting in a significant reduction in the cost associated with merging record batches using the `SortPreservingMerge` operator. Benchmarks improved as follows: ``` ⇒ critcmp master pr group master pr ----- ------ -- interleave_batches 1.83 623.8±12.41µs ? ?/sec 1.00 341.2±6.98µs ? ?/sec merge_batches_no_overlap_large 1.56 400.6±4.94µs ? ?/sec 1.00 256.3±6.57µs ? ?/sec merge_batches_no_overlap_small 1.63 425.1±24.88µs ? ?/sec 1.00 261.1±7.46µs ? ?/sec merge_batches_small_into_large 1.18 228.0±3.95µs ? ?/sec 1.00 193.6±2.86µs ? ?/sec merge_batches_some_overlap_large 1.68 505.4±10.27µs ? ?/sec 1.00 301.3±6.63µs ? ?/sec merge_batches_some_overlap_small 1.64 515.7±5.21µs ? ?/sec 1.00 314.6±12.66µs ? ?/sec ``` * test: test more than two partitions
1 parent afe29bd commit bd3ee23

File tree

1 file changed

+145
-37
lines changed

1 file changed

+145
-37
lines changed

datafusion/src/physical_plan/sort_preserving_merge.rs

Lines changed: 145 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::pin::Pin;
2424
use std::sync::Arc;
2525
use std::task::{Context, Poll};
2626

27+
use arrow::array::DynComparator;
2728
use arrow::{
2829
array::{make_array as make_arrow_array, ArrayRef, MutableArrayData},
2930
compute::SortOptions,
@@ -35,6 +36,7 @@ use async_trait::async_trait;
3536
use futures::channel::mpsc;
3637
use futures::stream::FusedStream;
3738
use futures::{Stream, StreamExt};
39+
use hashbrown::HashMap;
3840

3941
use crate::error::{DataFusionError, Result};
4042
use crate::physical_plan::{
@@ -176,34 +178,60 @@ impl ExecutionPlan for SortPreservingMergeExec {
176178
}
177179
}
178180

179-
/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of `PhysicalExpr` that when
180-
/// evaluated on the `RecordBatch` yield the sort keys.
181+
/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
182+
/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
181183
///
182184
/// Additionally it maintains a row cursor that can be advanced through the rows
183185
/// of the provided `RecordBatch`
184186
///
185-
/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this
186-
/// row cursor, with that of another `SortKeyCursor`
187-
#[derive(Debug, Clone)]
187+
/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
188+
/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
189+
/// a row comparator for each other cursor that it is compared to.
188190
struct SortKeyCursor {
189191
columns: Vec<ArrayRef>,
190-
batch: RecordBatch,
191192
cur_row: usize,
192193
num_rows: usize,
194+
195+
// An index uniquely identifying the record batch scanned by this cursor.
196+
batch_idx: usize,
197+
batch: RecordBatch,
198+
199+
// A collection of comparators that compare rows in this cursor's batch to
200+
// the cursors in other batches. Other batches are uniquely identified by
201+
// their batch_idx.
202+
batch_comparators: HashMap<usize, Vec<DynComparator>>,
203+
}
204+
205+
impl<'a> std::fmt::Debug for SortKeyCursor {
206+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207+
f.debug_struct("SortKeyCursor")
208+
.field("columns", &self.columns)
209+
.field("cur_row", &self.cur_row)
210+
.field("num_rows", &self.num_rows)
211+
.field("batch_idx", &self.batch_idx)
212+
.field("batch", &self.batch)
213+
.field("batch_comparators", &"<FUNC>")
214+
.finish()
215+
}
193216
}
194217

195218
impl SortKeyCursor {
196-
fn new(batch: RecordBatch, sort_key: &[Arc<dyn PhysicalExpr>]) -> Result<Self> {
219+
fn new(
220+
batch_idx: usize,
221+
batch: RecordBatch,
222+
sort_key: &[Arc<dyn PhysicalExpr>],
223+
) -> Result<Self> {
197224
let columns = sort_key
198225
.iter()
199226
.map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
200227
.collect::<Result<_>>()?;
201-
202228
Ok(Self {
203229
cur_row: 0,
204230
num_rows: batch.num_rows(),
205231
columns,
206232
batch,
233+
batch_idx,
234+
batch_comparators: HashMap::new(),
207235
})
208236
}
209237

@@ -220,7 +248,7 @@ impl SortKeyCursor {
220248

221249
/// Compares the sort key pointed to by this instance's row cursor with that of another
222250
fn compare(
223-
&self,
251+
&mut self,
224252
other: &SortKeyCursor,
225253
options: &[SortOptions],
226254
) -> Result<Ordering> {
@@ -246,7 +274,19 @@ impl SortKeyCursor {
246274
.zip(other.columns.iter())
247275
.zip(options.iter());
248276

249-
for ((l, r), sort_options) in zipped {
277+
// Recall or initialise a collection of comparators for comparing
278+
// columnar arrays of this cursor and "other".
279+
let cmp = self
280+
.batch_comparators
281+
.entry(other.batch_idx)
282+
.or_insert_with(|| Vec::with_capacity(other.columns.len()));
283+
284+
for (i, ((l, r), sort_options)) in zipped.enumerate() {
285+
if i >= cmp.len() {
286+
// initialise comparators as potentially needed
287+
cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?);
288+
}
289+
250290
match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
251291
(false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
252292
(false, true) => return Ok(Ordering::Greater),
@@ -255,15 +295,11 @@ impl SortKeyCursor {
255295
}
256296
(true, false) => return Ok(Ordering::Less),
257297
(false, false) => {}
258-
(true, true) => {
259-
// TODO: Building the predicate each time is sub-optimal
260-
let c = arrow::array::build_compare(l.as_ref(), r.as_ref())?;
261-
match c(self.cur_row, other.cur_row) {
262-
Ordering::Equal => {}
263-
o if sort_options.descending => return Ok(o.reverse()),
264-
o => return Ok(o),
265-
}
266-
}
298+
(true, true) => match cmp[i](self.cur_row, other.cur_row) {
299+
Ordering::Equal => {}
300+
o if sort_options.descending => return Ok(o.reverse()),
301+
o => return Ok(o),
302+
},
267303
}
268304
}
269305

@@ -304,6 +340,9 @@ struct SortPreservingMergeStream {
304340
target_batch_size: usize,
305341
/// If the stream has encountered an error
306342
aborted: bool,
343+
344+
/// An index to uniquely identify the input stream batch
345+
next_batch_index: usize,
307346
}
308347

309348
impl SortPreservingMergeStream {
@@ -313,15 +352,21 @@ impl SortPreservingMergeStream {
313352
expressions: &[PhysicalSortExpr],
314353
target_batch_size: usize,
315354
) -> Self {
355+
let cursors = (0..streams.len())
356+
.into_iter()
357+
.map(|_| VecDeque::new())
358+
.collect();
359+
316360
Self {
317361
schema,
318-
cursors: vec![Default::default(); streams.len()],
362+
cursors,
319363
streams,
320364
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
321365
sort_options: expressions.iter().map(|x| x.options).collect(),
322366
target_batch_size,
323367
aborted: false,
324368
in_progress: vec![],
369+
next_batch_index: 0,
325370
}
326371
}
327372

@@ -352,12 +397,17 @@ impl SortPreservingMergeStream {
352397
return Poll::Ready(Err(e));
353398
}
354399
Some(Ok(batch)) => {
355-
let cursor = match SortKeyCursor::new(batch, &self.column_expressions) {
400+
let cursor = match SortKeyCursor::new(
401+
self.next_batch_index, // assign this batch an ID
402+
batch,
403+
&self.column_expressions,
404+
) {
356405
Ok(cursor) => cursor,
357406
Err(e) => {
358407
return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
359408
}
360409
};
410+
self.next_batch_index += 1;
361411
self.cursors[idx].push_back(cursor)
362412
}
363413
}
@@ -367,17 +417,17 @@ impl SortPreservingMergeStream {
367417

368418
/// Returns the index of the next stream to pull a row from, or None
369419
/// if all cursors for all streams are exhausted
370-
fn next_stream_idx(&self) -> Result<Option<usize>> {
371-
let mut min_cursor: Option<(usize, &SortKeyCursor)> = None;
372-
for (idx, candidate) in self.cursors.iter().enumerate() {
373-
if let Some(candidate) = candidate.back() {
420+
fn next_stream_idx(&mut self) -> Result<Option<usize>> {
421+
let mut min_cursor: Option<(usize, &mut SortKeyCursor)> = None;
422+
for (idx, candidate) in self.cursors.iter_mut().enumerate() {
423+
if let Some(candidate) = candidate.back_mut() {
374424
if candidate.is_finished() {
375425
continue;
376426
}
377427

378428
match min_cursor {
379429
None => min_cursor = Some((idx, candidate)),
380-
Some((_, min)) => {
430+
Some((_, ref mut min)) => {
381431
if min.compare(candidate, &self.sort_options)?
382432
== Ordering::Greater
383433
{
@@ -599,8 +649,7 @@ mod tests {
599649
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
600650

601651
_test_merge(
602-
b1,
603-
b2,
652+
&[vec![b1], vec![b2]],
604653
&[
605654
"+----+---+-------------------------------+",
606655
"| a | b | c |",
@@ -646,8 +695,7 @@ mod tests {
646695
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
647696

648697
_test_merge(
649-
b1,
650-
b2,
698+
&[vec![b1], vec![b2]],
651699
&[
652700
"+-----+---+-------------------------------+",
653701
"| a | b | c |",
@@ -693,8 +741,7 @@ mod tests {
693741
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
694742

695743
_test_merge(
696-
b1,
697-
b2,
744+
&[vec![b1], vec![b2]],
698745
&[
699746
"+----+---+-------------------------------+",
700747
"| a | b | c |",
@@ -715,8 +762,71 @@ mod tests {
715762
.await;
716763
}
717764

718-
async fn _test_merge(b1: RecordBatch, b2: RecordBatch, exp: &[&str]) {
719-
let schema = b1.schema();
765+
#[tokio::test]
766+
async fn test_merge_three_partitions() {
767+
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
768+
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
769+
Some("a"),
770+
Some("b"),
771+
Some("c"),
772+
Some("d"),
773+
Some("f"),
774+
]));
775+
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
776+
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
777+
778+
let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
779+
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
780+
Some("e"),
781+
Some("g"),
782+
Some("h"),
783+
Some("i"),
784+
Some("j"),
785+
]));
786+
let c: ArrayRef =
787+
Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
788+
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
789+
790+
let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
791+
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
792+
Some("f"),
793+
Some("g"),
794+
Some("h"),
795+
Some("i"),
796+
Some("j"),
797+
]));
798+
let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
799+
let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
800+
801+
_test_merge(
802+
&[vec![b1], vec![b2], vec![b3]],
803+
&[
804+
"+-----+---+-------------------------------+",
805+
"| a | b | c |",
806+
"+-----+---+-------------------------------+",
807+
"| 1 | a | 1970-01-01 00:00:00.000000008 |",
808+
"| 2 | b | 1970-01-01 00:00:00.000000007 |",
809+
"| 7 | c | 1970-01-01 00:00:00.000000006 |",
810+
"| 9 | d | 1970-01-01 00:00:00.000000005 |",
811+
"| 10 | e | 1970-01-01 00:00:00.000000040 |",
812+
"| 100 | f | 1970-01-01 00:00:00.000000004 |",
813+
"| 3 | f | 1970-01-01 00:00:00.000000008 |",
814+
"| 200 | g | 1970-01-01 00:00:00.000000006 |",
815+
"| 20 | g | 1970-01-01 00:00:00.000000060 |",
816+
"| 700 | h | 1970-01-01 00:00:00.000000002 |",
817+
"| 70 | h | 1970-01-01 00:00:00.000000020 |",
818+
"| 900 | i | 1970-01-01 00:00:00.000000002 |",
819+
"| 90 | i | 1970-01-01 00:00:00.000000020 |",
820+
"| 300 | j | 1970-01-01 00:00:00.000000006 |",
821+
"| 30 | j | 1970-01-01 00:00:00.000000060 |",
822+
"+-----+---+-------------------------------+",
823+
],
824+
)
825+
.await;
826+
}
827+
828+
async fn _test_merge(partitions: &[Vec<RecordBatch>], exp: &[&str]) {
829+
let schema = partitions[0][0].schema();
720830
let sort = vec![
721831
PhysicalSortExpr {
722832
expr: col("b", &schema).unwrap(),
@@ -727,12 +837,10 @@ mod tests {
727837
options: Default::default(),
728838
},
729839
];
730-
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
840+
let exec = MemoryExec::try_new(partitions, schema, None).unwrap();
731841
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024));
732842

733843
let collected = collect(merge).await.unwrap();
734-
assert_eq!(collected.len(), 1);
735-
736844
assert_batches_eq!(exp, collected.as_slice());
737845
}
738846

0 commit comments

Comments
 (0)