diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs new file mode 100644 index 000000000000..ebe4f95e2095 --- /dev/null +++ b/datafusion/core/src/physical_plan/sorts/cursor.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::PhysicalExpr; +use arrow::array::{ArrayRef, DynComparator}; +use arrow::compute::SortOptions; +use arrow::record_batch::RecordBatch; +use hashbrown::HashMap; +use parking_lot::RwLock; +use std::borrow::BorrowMut; +use std::cmp::Ordering; +use std::sync::Arc; + +/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of +/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. +/// +/// Additionally it maintains a row cursor that can be advanced through the rows +/// of the provided `RecordBatch` +/// +/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to +/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores +/// a row comparator for each other cursor that it is compared to. 
+pub struct SortKeyCursor { + stream_idx: usize, + sort_columns: Vec, + cur_row: usize, + num_rows: usize, + + // An id uniquely identifying the record batch scanned by this cursor. + batch_id: usize, + + // A collection of comparators that compare rows in this cursor's batch to + // the cursors in other batches. Other batches are uniquely identified by + // their batch_id. + batch_comparators: RwLock>>, + sort_options: Arc>, +} + +impl<'a> std::fmt::Debug for SortKeyCursor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("SortKeyCursor") + .field("sort_columns", &self.sort_columns) + .field("cur_row", &self.cur_row) + .field("num_rows", &self.num_rows) + .field("batch_id", &self.batch_id) + .field("batch_comparators", &"") + .finish() + } +} + +impl SortKeyCursor { + /// Create a new SortKeyCursor + pub fn new( + stream_idx: usize, + batch_id: usize, + batch: &RecordBatch, + sort_key: &[Arc], + sort_options: Arc>, + ) -> error::Result { + let sort_columns = sort_key + .iter() + .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) + .collect::>()?; + Ok(Self { + stream_idx, + cur_row: 0, + num_rows: batch.num_rows(), + sort_columns, + batch_id, + batch_comparators: RwLock::new(HashMap::new()), + sort_options, + }) + } + + #[inline(always)] + /// Return the stream index of this cursor + pub fn stream_idx(&self) -> usize { + self.stream_idx + } + + #[inline(always)] + /// Return the batch id of this cursor + pub fn batch_id(&self) -> usize { + self.batch_id + } + + #[inline(always)] + /// Return true if every row of this cursor's batch has been consumed + pub fn is_finished(&self) -> bool { + self.num_rows == self.cur_row + } + + #[inline(always)] + /// Returns the cursor's current row, and advances the cursor to the next row + pub fn advance(&mut self) -> usize { + assert!(!self.is_finished()); + let t = self.cur_row; + self.cur_row += 1; + t + } + + /// Compares the sort key pointed to by this instance's row cursor with that of another + 
pub fn compare(&self, other: &SortKeyCursor) -> error::Result { + if self.sort_columns.len() != other.sort_columns.len() { + return Err(DataFusionError::Internal(format!( + "SortKeyCursors had inconsistent column counts: {} vs {}", + self.sort_columns.len(), + other.sort_columns.len() + ))); + } + + if self.sort_columns.len() != self.sort_options.len() { + return Err(DataFusionError::Internal(format!( + "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", + self.sort_columns.len(), + self.sort_options.len() + ))); + } + + let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self + .sort_columns + .iter() + .zip(other.sort_columns.iter()) + .zip(self.sort_options.iter()) + .collect::>(); + + self.init_cmp_if_needed(other, &zipped)?; + let map = self.batch_comparators.read(); + let cmp = map.get(&other.batch_id).ok_or_else(|| { + DataFusionError::Execution(format!( + "Failed to find comparator for {} cmp {}", + self.batch_id, other.batch_id + )) + })?; + + for (i, ((l, r), sort_options)) in zipped.iter().enumerate() { + match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { + (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), + (false, true) => return Ok(Ordering::Greater), + (true, false) if sort_options.nulls_first => { + return Ok(Ordering::Greater) + } + (true, false) => return Ok(Ordering::Less), + (false, false) => {} + (true, true) => match cmp[i](self.cur_row, other.cur_row) { + Ordering::Equal => {} + o if sort_options.descending => return Ok(o.reverse()), + o => return Ok(o), + }, + } + } + + // Break ties using stream_idx to ensure a predictable + // ordering of rows when comparing equal streams. + Ok(self.stream_idx.cmp(&other.stream_idx)) + } + + /// Initialize a collection of comparators for comparing + /// columnar arrays of this cursor and "other" if needed. 
+ fn init_cmp_if_needed( + &self, + other: &SortKeyCursor, + zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], + ) -> Result<()> { + let hm = self.batch_comparators.read(); + if !hm.contains_key(&other.batch_id) { + drop(hm); + let mut map = self.batch_comparators.write(); + let cmp = map + .borrow_mut() + .entry(other.batch_id) + .or_insert_with(|| Vec::with_capacity(other.sort_columns.len())); + + for (i, ((l, r), _)) in zipped.iter().enumerate() { + if i >= cmp.len() { + // initialise comparators + cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?); + } + } + } + Ok(()) + } +} + +impl Ord for SortKeyCursor { + /// Needed by min-heap comparison and reverse the order at the same time. + fn cmp(&self, other: &Self) -> Ordering { + other.compare(self).unwrap() + } +} + +impl PartialEq for SortKeyCursor { + fn eq(&self, other: &Self) -> bool { + other.compare(self).unwrap() == Ordering::Equal + } +} + +impl Eq for SortKeyCursor {} + +impl PartialOrd for SortKeyCursor { + fn partial_cmp(&self, other: &Self) -> Option { + other.compare(self).ok() + } +} diff --git a/datafusion/core/src/physical_plan/sorts/index.rs b/datafusion/core/src/physical_plan/sorts/index.rs new file mode 100644 index 000000000000..3b45c6d38770 --- /dev/null +++ b/datafusion/core/src/physical_plan/sorts/index.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// A `RowIndex` identifies a specific row in a logical stream. +/// +/// Each stream is identified by a `stream_idx` and is formed from a +/// sequence of RecordBatches, each of which is identified by +/// a unique `batch_idx` within that stream. +/// +/// This is used by `SortPreservingMergeStream` to identify +/// the order of the tuples in the final sorted output stream. +/// +/// ```text +/// ┌────┐ ┌────┐ ┌────┐ RecordBatch +/// │ │ │ │ │ │ +/// │ C1 │ │... │ │ CN │◀─────── (batch_idx = 0) +/// │ │ │ │ │ │ +/// └────┘ └────┘ └────┘ +/// ┌────┐ ┌────┐ ┌────┐ RecordBatch +/// │ │ │ │ │ │ +/// │ C1 │ │... │ │ CN │◀─────── (batch_idx = 1) +/// │ │ │ │ │ │ +/// └────┘ └────┘ └────┘ +/// ┌────┐ +/// │ │ ... +/// │ C1 │ +/// │ │ ┌────┐ RecordBatch +/// └────┘ │ │ +/// │ CN │◀────── (batch_idx = M-1) +/// │ │ +/// └────┘ +/// +///"Stream"s each with Stream N has M +/// a potentially RecordBatches +///different number of +/// RecordBatches +/// ``` +#[derive(Debug, Clone)] +pub struct RowIndex { + /// The index of the stream (uniquely identifies the stream) + pub stream_idx: usize, + /// The index of the batch within the stream's VecDeque. + pub batch_idx: usize, + /// The row index within the batch + pub row_idx: usize, +} diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs index 8d499be3a288..db6ab5c604e2 100644 --- a/datafusion/core/src/physical_plan/sorts/mod.rs +++ b/datafusion/core/src/physical_plan/sorts/mod.rs @@ -17,208 +17,16 @@ //! 
Sort functionalities -use crate::error; -use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{PhysicalExpr, SendableRecordBatchStream}; -use arrow::array::{ArrayRef, DynComparator}; -use arrow::compute::SortOptions; -use arrow::record_batch::RecordBatch; -use hashbrown::HashMap; -use parking_lot::RwLock; -use std::borrow::BorrowMut; -use std::cmp::Ordering; +use crate::physical_plan::SendableRecordBatchStream; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +mod cursor; +mod index; pub mod sort; pub mod sort_preserving_merge; -/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of -/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. -/// -/// Additionally it maintains a row cursor that can be advanced through the rows -/// of the provided `RecordBatch` -/// -/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to -/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores -/// a row comparator for each other cursor that it is compared to. -struct SortKeyCursor { - stream_idx: usize, - sort_columns: Vec, - cur_row: usize, - num_rows: usize, - - // An id uniquely identifying the record batch scanned by this cursor. - batch_id: usize, - - // A collection of comparators that compare rows in this cursor's batch to - // the cursors in other batches. Other batches are uniquely identified by - // their batch_idx. 
- batch_comparators: RwLock>>, - sort_options: Arc>, -} - -impl<'a> std::fmt::Debug for SortKeyCursor { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("SortKeyCursor") - .field("sort_columns", &self.sort_columns) - .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) - .field("batch_id", &self.batch_id) - .field("batch_comparators", &"") - .finish() - } -} - -impl SortKeyCursor { - fn new( - stream_idx: usize, - batch_id: usize, - batch: &RecordBatch, - sort_key: &[Arc], - sort_options: Arc>, - ) -> error::Result { - let sort_columns = sort_key - .iter() - .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows()))) - .collect::>()?; - Ok(Self { - stream_idx, - cur_row: 0, - num_rows: batch.num_rows(), - sort_columns, - batch_id, - batch_comparators: RwLock::new(HashMap::new()), - sort_options, - }) - } - - fn is_finished(&self) -> bool { - self.num_rows == self.cur_row - } - - fn advance(&mut self) -> usize { - assert!(!self.is_finished()); - let t = self.cur_row; - self.cur_row += 1; - t - } - - /// Compares the sort key pointed to by this instance's row cursor with that of another - fn compare(&self, other: &SortKeyCursor) -> error::Result { - if self.sort_columns.len() != other.sort_columns.len() { - return Err(DataFusionError::Internal(format!( - "SortKeyCursors had inconsistent column counts: {} vs {}", - self.sort_columns.len(), - other.sort_columns.len() - ))); - } - - if self.sort_columns.len() != self.sort_options.len() { - return Err(DataFusionError::Internal(format!( - "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", - self.sort_columns.len(), - self.sort_options.len() - ))); - } - - let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self - .sort_columns - .iter() - .zip(other.sort_columns.iter()) - .zip(self.sort_options.iter()) - .collect::>(); - - self.init_cmp_if_needed(other, &zipped)?; - let map = self.batch_comparators.read(); - let cmp = 
map.get(&other.batch_id).ok_or_else(|| { - DataFusionError::Execution(format!( - "Failed to find comparator for {} cmp {}", - self.batch_id, other.batch_id - )) - })?; - - for (i, ((l, r), sort_options)) in zipped.iter().enumerate() { - match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { - (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), - (false, true) => return Ok(Ordering::Greater), - (true, false) if sort_options.nulls_first => { - return Ok(Ordering::Greater) - } - (true, false) => return Ok(Ordering::Less), - (false, false) => {} - (true, true) => match cmp[i](self.cur_row, other.cur_row) { - Ordering::Equal => {} - o if sort_options.descending => return Ok(o.reverse()), - o => return Ok(o), - }, - } - } - - // Break ties using stream_idx to ensure a predictable - // ordering of rows when comparing equal streams. - Ok(self.stream_idx.cmp(&other.stream_idx)) - } - - /// Initialize a collection of comparators for comparing - /// columnar arrays of this cursor and "other" if needed. - fn init_cmp_if_needed( - &self, - other: &SortKeyCursor, - zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], - ) -> Result<()> { - let hm = self.batch_comparators.read(); - if !hm.contains_key(&other.batch_id) { - drop(hm); - let mut map = self.batch_comparators.write(); - let cmp = map - .borrow_mut() - .entry(other.batch_id) - .or_insert_with(|| Vec::with_capacity(other.sort_columns.len())); - - for (i, ((l, r), _)) in zipped.iter().enumerate() { - if i >= cmp.len() { - // initialise comparators - cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?); - } - } - } - Ok(()) - } -} - -impl Ord for SortKeyCursor { - /// Needed by min-heap comparison and reverse the order at the same time. 
- fn cmp(&self, other: &Self) -> Ordering { - other.compare(self).unwrap() - } -} - -impl PartialEq for SortKeyCursor { - fn eq(&self, other: &Self) -> bool { - other.compare(self).unwrap() == Ordering::Equal - } -} - -impl Eq for SortKeyCursor {} - -impl PartialOrd for SortKeyCursor { - fn partial_cmp(&self, other: &Self) -> Option { - other.compare(self).ok() - } -} - -/// A `RowIndex` identifies a specific row from those buffered -/// by a `SortPreservingMergeStream` -#[derive(Debug, Clone)] -struct RowIndex { - /// The index of the stream - stream_idx: usize, - /// The index of the batch within the stream's VecDequeue. - batch_idx: usize, - /// The row index - row_idx: usize, -} +pub use cursor::SortKeyCursor; +pub use index::RowIndex; pub(crate) struct SortedStream { stream: SendableRecordBatchStream, diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 515300eff5ce..f7ce73834b66 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -551,7 +551,7 @@ impl SortPreservingMergeStream { match self.min_heap.pop() { Some(mut cursor) => { - let stream_idx = cursor.stream_idx; + let stream_idx = cursor.stream_idx(); let batch_idx = self.batches[stream_idx].len() - 1; let row_idx = cursor.advance(); diff --git a/datafusion/core/tests/sort_key_cursor.rs b/datafusion/core/tests/sort_key_cursor.rs new file mode 100644 index 000000000000..7672ea577b4b --- /dev/null +++ b/datafusion/core/tests/sort_key_cursor.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains tests for SortKeyCursor + +use std::{cmp::Ordering, sync::Arc}; + +use arrow::{array::Int64Array, compute::SortOptions, record_batch::RecordBatch}; +use datafusion::physical_plan::sorts::{RowIndex, SortKeyCursor}; +use datafusion_physical_expr::expressions::col; + +#[test] +fn test_single_column() { + let batch1 = int64_batch(vec![Some(1), Some(2), Some(5), Some(6)]); + let batch2 = int64_batch(vec![Some(3), Some(4), Some(8), Some(9)]); + + let mut cursor1 = CursorBuilder::new(batch1) + .with_stream_idx(11) + .with_batch_id(0) + .build(); + + let mut cursor2 = CursorBuilder::new(batch2) + .with_stream_idx(22) + .with_batch_id(0) + .build(); + + let expected = vec![ + "11: (0, 0)", + "11: (0, 1)", + "22: (0, 0)", + "22: (0, 1)", + "11: (0, 2)", + "11: (0, 3)", + "22: (0, 2)", + "22: (0, 3)", + ]; + + assert_indexes(expected, run(&mut cursor1, &mut cursor2)); +} + +#[test] +fn test_stable_compare() { + // Validate ties are broken by the lower stream idx to ensure stable sort + let batch1 = int64_batch(vec![Some(3), Some(4)]); + let batch2 = int64_batch(vec![Some(3)]); + + let cursor1 = CursorBuilder::new(batch1) + // higher stream index + .with_stream_idx(33) + .with_batch_id(0); + + let cursor2 = CursorBuilder::new(batch2) + // Lower stream index -- should always be first + .with_stream_idx(22) + .with_batch_id(0); + + let expected = vec!["22: (0, 0)", "33: (0, 0)", "33: 
(0, 1)"]; + + // Output should be the same, regardless of order + assert_indexes( + &expected, + run(&mut cursor1.clone().build(), &mut cursor2.clone().build()), + ); + assert_indexes(&expected, run(&mut cursor2.build(), &mut cursor1.build())); +} + +/// Runs the two cursors to completion, merging their rows, and +/// returns the sorted order in which the rows would have been produced +fn run(cursor1: &mut SortKeyCursor, cursor2: &mut SortKeyCursor) -> Vec { + let mut indexes = vec![]; + loop { + println!( + "(cursor1.is_finished(), cursor2.is_finished()): ({}, {})", + cursor1.is_finished(), + cursor2.is_finished() + ); + + match (cursor1.is_finished(), cursor2.is_finished()) { + (true, true) => return indexes, + (true, false) => return drain(cursor2, indexes), + (false, true) => return drain(cursor1, indexes), + // both cursors have more rows + (false, false) => match cursor1.compare(cursor2).unwrap() { + Ordering::Less => { + indexes.push(advance(cursor1)); + } + Ordering::Equal => { + indexes.push(advance(cursor1)); + indexes.push(advance(cursor2)); + } + Ordering::Greater => { + indexes.push(advance(cursor2)); + } + }, + } + } +} + +// Advance the cursor and return the RowIndex created +fn advance(cursor: &mut SortKeyCursor) -> RowIndex { + let row_idx = cursor.advance(); + RowIndex { + stream_idx: cursor.stream_idx(), + batch_idx: cursor.batch_id(), + row_idx, + } +} + +// Drain remaining items in the cursor, appending result to indexes +fn drain(cursor: &mut SortKeyCursor, mut indexes: Vec) -> Vec { + while !cursor.is_finished() { + indexes.push(advance(cursor)) + } + indexes +} + +/// Return the values as a record batch with a single [`Int64Array`] +/// column named "c1" +fn int64_batch(values: impl IntoIterator>) -> RecordBatch { + let array: Int64Array = values.into_iter().collect(); + RecordBatch::try_from_iter(vec![("c1", Arc::new(array) as _)]).unwrap() +} + +/// helper for creating cursors to test +#[derive(Debug, Clone)] +struct CursorBuilder { + batch: RecordBatch, + 
stream_idx: Option, + batch_id: Option, +} + +impl CursorBuilder { + fn new(batch: RecordBatch) -> Self { + Self { + batch, + stream_idx: None, + batch_id: None, + } + } + + /// Set the stream index + fn with_stream_idx(mut self, stream_idx: usize) -> Self { + self.stream_idx = Some(stream_idx); + self + } + + /// Set the batch id + fn with_batch_id(mut self, batch_id: usize) -> Self { + self.batch_id = Some(batch_id); + self + } + + fn build(self) -> SortKeyCursor { + let Self { + batch, + stream_idx, + batch_id, + } = self; + let c1 = col("c1", &batch.schema()).unwrap(); + let sort_key = vec![c1]; + + let sort_options = Arc::new(vec![SortOptions::default()]); + + SortKeyCursor::new( + stream_idx.expect("stream idx not set"), + batch_id.expect("batch id not set"), + &batch, + &sort_key, + sort_options, + ) + .unwrap() + } +} + +/// Compares [`RowIndex`]es with a vector of strings, the result of +/// pretty formatting the [`RowIndex`]es. +/// +/// Designed so that failure output can be directly copy/pasted +/// into the test code as expected results. +fn assert_indexes( + expected_indexes: impl IntoIterator>, + indexes: impl IntoIterator, +) { + let expected_lines: Vec<_> = expected_indexes + .into_iter() + .map(|s| s.as_ref().to_string()) + .collect(); + + let actual_lines = format_as_strings(indexes); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + +/// Formats an iterator of RowIndexes into strings for comparisons +/// +/// ```text +/// stream: (batch, index) +/// ``` +/// +/// for example, +/// ```text +/// 1: (0, 2) +/// ``` +/// means "Stream 1, batch id 0, row index 2" +fn format_as_strings(indexes: impl IntoIterator) -> Vec { + indexes + .into_iter() + .map(|row_index| { + format!( + "{}: ({}, {})", + row_index.stream_idx, row_index.batch_idx, row_index.row_idx + ) + }) + .collect() +}