-
Notifications
You must be signed in to change notification settings - Fork 1.8k
perf: improve performance of SortPreservingMergeExec operator
#722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ use std::pin::Pin; | |
| use std::sync::Arc; | ||
| use std::task::{Context, Poll}; | ||
|
|
||
| use arrow::array::DynComparator; | ||
| use arrow::{ | ||
| array::{make_array as make_arrow_array, ArrayRef, MutableArrayData}, | ||
| compute::SortOptions, | ||
|
|
@@ -35,6 +36,7 @@ use async_trait::async_trait; | |
| use futures::channel::mpsc; | ||
| use futures::stream::FusedStream; | ||
| use futures::{Stream, StreamExt}; | ||
| use hashbrown::HashMap; | ||
|
|
||
| use crate::error::{DataFusionError, Result}; | ||
| use crate::physical_plan::{ | ||
|
|
@@ -176,34 +178,60 @@ impl ExecutionPlan for SortPreservingMergeExec { | |
| } | ||
| } | ||
|
|
||
| /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of `PhysicalExpr` that when | ||
| /// evaluated on the `RecordBatch` yield the sort keys. | ||
| /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of | ||
| /// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. | ||
| /// | ||
| /// Additionally it maintains a row cursor that can be advanced through the rows | ||
| /// of the provided `RecordBatch` | ||
| /// | ||
| /// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this | ||
| /// row cursor, with that of another `SortKeyCursor` | ||
| #[derive(Debug, Clone)] | ||
| /// `SortKeyCursor::compare` can then be used to compare the sort key pointed to | ||
| /// by this row cursor, with that of another `SortKeyCursor`. A cursor stores | ||
| /// a row comparator for each other cursor that it is compared to. | ||
| struct SortKeyCursor { | ||
| columns: Vec<ArrayRef>, | ||
| batch: RecordBatch, | ||
| cur_row: usize, | ||
| num_rows: usize, | ||
|
|
||
| // An index uniquely identifying the record batch scanned by this cursor. | ||
| batch_idx: usize, | ||
| batch: RecordBatch, | ||
|
|
||
| // A collection of comparators that compare rows in this cursor's batch to | ||
| // the cursors in other batches. Other batches are uniquely identified by | ||
| // their batch_idx. | ||
| batch_comparators: HashMap<usize, Vec<DynComparator>>, | ||
| } | ||
|
|
||
| impl<'a> std::fmt::Debug for SortKeyCursor { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.debug_struct("SortKeyCursor") | ||
| .field("columns", &self.columns) | ||
| .field("cur_row", &self.cur_row) | ||
| .field("num_rows", &self.num_rows) | ||
| .field("batch_idx", &self.batch_idx) | ||
| .field("batch", &self.batch) | ||
| .field("batch_comparators", &"<FUNC>") | ||
| .finish() | ||
| } | ||
| } | ||
|
|
||
| impl SortKeyCursor { | ||
| fn new(batch: RecordBatch, sort_key: &[Arc<dyn PhysicalExpr>]) -> Result<Self> { | ||
| fn new( | ||
| batch_idx: usize, | ||
| batch: RecordBatch, | ||
| sort_key: &[Arc<dyn PhysicalExpr>], | ||
| ) -> Result<Self> { | ||
| let columns = sort_key | ||
| .iter() | ||
| .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) | ||
| .collect::<Result<_>>()?; | ||
|
|
||
| Ok(Self { | ||
| cur_row: 0, | ||
| num_rows: batch.num_rows(), | ||
| columns, | ||
| batch, | ||
| batch_idx, | ||
| batch_comparators: HashMap::new(), | ||
| }) | ||
| } | ||
|
|
||
|
|
@@ -220,7 +248,7 @@ impl SortKeyCursor { | |
|
|
||
| /// Compares the sort key pointed to by this instance's row cursor with that of another | ||
| fn compare( | ||
| &self, | ||
| &mut self, | ||
| other: &SortKeyCursor, | ||
| options: &[SortOptions], | ||
| ) -> Result<Ordering> { | ||
|
|
@@ -246,7 +274,19 @@ impl SortKeyCursor { | |
| .zip(other.columns.iter()) | ||
| .zip(options.iter()); | ||
|
|
||
| for ((l, r), sort_options) in zipped { | ||
| // Recall or initialise a collection of comparators for comparing | ||
| // columnar arrays of this cursor and "other". | ||
| let cmp = self | ||
| .batch_comparators | ||
| .entry(other.batch_idx) | ||
| .or_insert_with(|| Vec::with_capacity(other.columns.len())); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was initially worried that mutating However, perhaps that avoids having to clone or collect Please feel free to ignore this comment
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ha. Yeah that's how I had it but you have to re-create the iterator. I'm happy to go with whatever
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this way is fine, personally
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we need to clean older
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean to reduce the peak memory usage by freeing comparators?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The HashMap only lives for as long as the lifetime of the sort preserving merge operator. I suppose you could merge n streams large enough to make the old comparators take up non-negligible memory.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's what I meant. The memory usage would increase over time by holding onto the older comparators.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I personally think the "right" thing to do in this case is to create an arrow comparator interface that doesn't have bound arrays, so like With that interface we could simply use the same comparator So that is to say, I recommend leaving this PR the way it is and putting our efforts into a better comparison interface rather than trying to optimize for a small amount of memory savings here |
||
|
|
||
| for (i, ((l, r), sort_options)) in zipped.enumerate() { | ||
| if i >= cmp.len() { | ||
| // initialise comparators as potentially needed | ||
| cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?); | ||
| } | ||
|
|
||
| match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { | ||
| (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), | ||
| (false, true) => return Ok(Ordering::Greater), | ||
|
|
@@ -255,15 +295,11 @@ impl SortKeyCursor { | |
| } | ||
| (true, false) => return Ok(Ordering::Less), | ||
| (false, false) => {} | ||
| (true, true) => { | ||
| // TODO: Building the predicate each time is sub-optimal | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎉 |
||
| let c = arrow::array::build_compare(l.as_ref(), r.as_ref())?; | ||
| match c(self.cur_row, other.cur_row) { | ||
| Ordering::Equal => {} | ||
| o if sort_options.descending => return Ok(o.reverse()), | ||
| o => return Ok(o), | ||
| } | ||
| } | ||
| (true, true) => match cmp[i](self.cur_row, other.cur_row) { | ||
| Ordering::Equal => {} | ||
| o if sort_options.descending => return Ok(o.reverse()), | ||
| o => return Ok(o), | ||
| }, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -304,6 +340,9 @@ struct SortPreservingMergeStream { | |
| target_batch_size: usize, | ||
| /// If the stream has encountered an error | ||
| aborted: bool, | ||
|
|
||
| /// An index to uniquely identify the input stream batch | ||
| next_batch_index: usize, | ||
| } | ||
|
|
||
| impl SortPreservingMergeStream { | ||
|
|
@@ -313,15 +352,21 @@ impl SortPreservingMergeStream { | |
| expressions: &[PhysicalSortExpr], | ||
| target_batch_size: usize, | ||
| ) -> Self { | ||
| let cursors = (0..streams.len()) | ||
| .into_iter() | ||
| .map(|_| VecDeque::new()) | ||
| .collect(); | ||
|
|
||
| Self { | ||
| schema, | ||
| cursors: vec![Default::default(); streams.len()], | ||
| cursors, | ||
| streams, | ||
| column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), | ||
| sort_options: expressions.iter().map(|x| x.options).collect(), | ||
| target_batch_size, | ||
| aborted: false, | ||
| in_progress: vec![], | ||
| next_batch_index: 0, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -352,12 +397,17 @@ impl SortPreservingMergeStream { | |
| return Poll::Ready(Err(e)); | ||
| } | ||
| Some(Ok(batch)) => { | ||
| let cursor = match SortKeyCursor::new(batch, &self.column_expressions) { | ||
| let cursor = match SortKeyCursor::new( | ||
| self.next_batch_index, // assign this batch an ID | ||
| batch, | ||
| &self.column_expressions, | ||
| ) { | ||
| Ok(cursor) => cursor, | ||
| Err(e) => { | ||
| return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e)))); | ||
| } | ||
| }; | ||
| self.next_batch_index += 1; | ||
| self.cursors[idx].push_back(cursor) | ||
| } | ||
| } | ||
|
|
@@ -367,17 +417,17 @@ impl SortPreservingMergeStream { | |
|
|
||
| /// Returns the index of the next stream to pull a row from, or None | ||
| /// if all cursors for all streams are exhausted | ||
| fn next_stream_idx(&self) -> Result<Option<usize>> { | ||
| let mut min_cursor: Option<(usize, &SortKeyCursor)> = None; | ||
| for (idx, candidate) in self.cursors.iter().enumerate() { | ||
| if let Some(candidate) = candidate.back() { | ||
| fn next_stream_idx(&mut self) -> Result<Option<usize>> { | ||
| let mut min_cursor: Option<(usize, &mut SortKeyCursor)> = None; | ||
| for (idx, candidate) in self.cursors.iter_mut().enumerate() { | ||
| if let Some(candidate) = candidate.back_mut() { | ||
| if candidate.is_finished() { | ||
| continue; | ||
| } | ||
|
|
||
| match min_cursor { | ||
| None => min_cursor = Some((idx, candidate)), | ||
| Some((_, min)) => { | ||
| Some((_, ref mut min)) => { | ||
| if min.compare(candidate, &self.sort_options)? | ||
| == Ordering::Greater | ||
| { | ||
|
|
@@ -599,8 +649,7 @@ mod tests { | |
| let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| _test_merge( | ||
| b1, | ||
| b2, | ||
| &[vec![b1], vec![b2]], | ||
| &[ | ||
| "+----+---+-------------------------------+", | ||
| "| a | b | c |", | ||
|
|
@@ -646,8 +695,7 @@ mod tests { | |
| let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| _test_merge( | ||
| b1, | ||
| b2, | ||
| &[vec![b1], vec![b2]], | ||
| &[ | ||
| "+-----+---+-------------------------------+", | ||
| "| a | b | c |", | ||
|
|
@@ -693,8 +741,7 @@ mod tests { | |
| let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| _test_merge( | ||
| b1, | ||
| b2, | ||
| &[vec![b1], vec![b2]], | ||
| &[ | ||
| "+----+---+-------------------------------+", | ||
| "| a | b | c |", | ||
|
|
@@ -715,8 +762,71 @@ mod tests { | |
| .await; | ||
| } | ||
|
|
||
| async fn _test_merge(b1: RecordBatch, b2: RecordBatch, exp: &[&str]) { | ||
| let schema = b1.schema(); | ||
| #[tokio::test] | ||
| async fn test_merge_three_partitions() { | ||
| let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); | ||
| let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ | ||
| Some("a"), | ||
| Some("b"), | ||
| Some("c"), | ||
| Some("d"), | ||
| Some("f"), | ||
| ])); | ||
| let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8])); | ||
| let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30])); | ||
| let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ | ||
| Some("e"), | ||
| Some("g"), | ||
| Some("h"), | ||
| Some("i"), | ||
| Some("j"), | ||
| ])); | ||
| let c: ArrayRef = | ||
| Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60])); | ||
| let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300])); | ||
| let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ | ||
| Some("f"), | ||
| Some("g"), | ||
| Some("h"), | ||
| Some("i"), | ||
| Some("j"), | ||
| ])); | ||
| let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6])); | ||
| let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); | ||
|
|
||
| _test_merge( | ||
| &[vec![b1], vec![b2], vec![b3]], | ||
| &[ | ||
| "+-----+---+-------------------------------+", | ||
| "| a | b | c |", | ||
| "+-----+---+-------------------------------+", | ||
| "| 1 | a | 1970-01-01 00:00:00.000000008 |", | ||
| "| 2 | b | 1970-01-01 00:00:00.000000007 |", | ||
| "| 7 | c | 1970-01-01 00:00:00.000000006 |", | ||
| "| 9 | d | 1970-01-01 00:00:00.000000005 |", | ||
| "| 10 | e | 1970-01-01 00:00:00.000000040 |", | ||
| "| 100 | f | 1970-01-01 00:00:00.000000004 |", | ||
| "| 3 | f | 1970-01-01 00:00:00.000000008 |", | ||
| "| 200 | g | 1970-01-01 00:00:00.000000006 |", | ||
| "| 20 | g | 1970-01-01 00:00:00.000000060 |", | ||
| "| 700 | h | 1970-01-01 00:00:00.000000002 |", | ||
| "| 70 | h | 1970-01-01 00:00:00.000000020 |", | ||
| "| 900 | i | 1970-01-01 00:00:00.000000002 |", | ||
| "| 90 | i | 1970-01-01 00:00:00.000000020 |", | ||
| "| 300 | j | 1970-01-01 00:00:00.000000006 |", | ||
| "| 30 | j | 1970-01-01 00:00:00.000000060 |", | ||
| "+-----+---+-------------------------------+", | ||
| ], | ||
| ) | ||
| .await; | ||
| } | ||
|
|
||
| async fn _test_merge(partitions: &[Vec<RecordBatch>], exp: &[&str]) { | ||
| let schema = partitions[0][0].schema(); | ||
| let sort = vec![ | ||
| PhysicalSortExpr { | ||
| expr: col("b", &schema).unwrap(), | ||
|
|
@@ -727,12 +837,10 @@ mod tests { | |
| options: Default::default(), | ||
| }, | ||
| ]; | ||
| let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); | ||
| let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); | ||
| let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); | ||
|
|
||
| let collected = collect(merge).await.unwrap(); | ||
| assert_eq!(collected.len(), 1); | ||
|
|
||
| assert_batches_eq!(exp, collected.as_slice()); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it worth mentioning that the index of the
Vecare the sort column positions (not other batch indexes)?