Skip to content

Commit dbc780a

Browse files
committed
Move SortKeyCursor and RowIndex into modules, add sort_key_cursor test
1 parent 2aa1eea commit dbc780a

File tree

5 files changed

+386
-198
lines changed

5 files changed

+386
-198
lines changed
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
use crate::error;
2+
use crate::error::{DataFusionError, Result};
3+
use crate::physical_plan::PhysicalExpr;
4+
use arrow::array::{ArrayRef, DynComparator};
5+
use arrow::compute::SortOptions;
6+
use arrow::record_batch::RecordBatch;
7+
use hashbrown::HashMap;
8+
use parking_lot::RwLock;
9+
use std::borrow::BorrowMut;
10+
use std::cmp::Ordering;
11+
use std::sync::Arc;
12+
13+
/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
14+
/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
15+
///
16+
/// Additionally it maintains a row cursor that can be advanced through the rows
17+
/// of the provided `RecordBatch`
18+
///
19+
/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
20+
/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
21+
/// a row comparator for each other cursor that it is compared to.
22+
pub struct SortKeyCursor {
23+
stream_idx: usize,
24+
sort_columns: Vec<ArrayRef>,
25+
cur_row: usize,
26+
num_rows: usize,
27+
28+
// An id uniquely identifying the record batch scanned by this cursor.
29+
batch_id: usize,
30+
31+
// A collection of comparators that compare rows in this cursor's batch to
32+
// the cursors in other batches. Other batches are uniquely identified by
33+
// their batch_idx.
34+
batch_comparators: RwLock<HashMap<usize, Vec<DynComparator>>>,
35+
sort_options: Arc<Vec<SortOptions>>,
36+
}
37+
38+
impl<'a> std::fmt::Debug for SortKeyCursor {
39+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
40+
f.debug_struct("SortKeyCursor")
41+
.field("sort_columns", &self.sort_columns)
42+
.field("cur_row", &self.cur_row)
43+
.field("num_rows", &self.num_rows)
44+
.field("batch_id", &self.batch_id)
45+
.field("batch_comparators", &"<FUNC>")
46+
.finish()
47+
}
48+
}
49+
50+
impl SortKeyCursor {
51+
/// Create a new SortKeyCursor
52+
pub fn new(
53+
stream_idx: usize,
54+
batch_id: usize,
55+
batch: &RecordBatch,
56+
sort_key: &[Arc<dyn PhysicalExpr>],
57+
sort_options: Arc<Vec<SortOptions>>,
58+
) -> error::Result<Self> {
59+
let sort_columns = sort_key
60+
.iter()
61+
.map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
62+
.collect::<error::Result<_>>()?;
63+
Ok(Self {
64+
stream_idx,
65+
cur_row: 0,
66+
num_rows: batch.num_rows(),
67+
sort_columns,
68+
batch_id,
69+
batch_comparators: RwLock::new(HashMap::new()),
70+
sort_options,
71+
})
72+
}
73+
74+
#[inline(always)]
75+
/// Return the stream index of this cursor
76+
pub fn stream_idx(&self) -> usize {
77+
self.stream_idx
78+
}
79+
80+
#[inline(always)]
81+
/// Return the batch id of this cursor
82+
pub fn batch_id(&self) -> usize {
83+
self.batch_id
84+
}
85+
86+
#[inline(always)]
87+
/// Return true if the stream is finished
88+
pub fn is_finished(&self) -> bool {
89+
self.num_rows == self.cur_row
90+
}
91+
92+
#[inline(always)]
93+
/// Returns the cursor's current row, and advances the cursor to the next row
94+
pub fn advance(&mut self) -> usize {
95+
assert!(!self.is_finished());
96+
let t = self.cur_row;
97+
self.cur_row += 1;
98+
t
99+
}
100+
101+
/// Compares the sort key pointed to by this instance's row cursor with that of another
102+
pub fn compare(&self, other: &SortKeyCursor) -> error::Result<Ordering> {
103+
if self.sort_columns.len() != other.sort_columns.len() {
104+
return Err(DataFusionError::Internal(format!(
105+
"SortKeyCursors had inconsistent column counts: {} vs {}",
106+
self.sort_columns.len(),
107+
other.sort_columns.len()
108+
)));
109+
}
110+
111+
if self.sort_columns.len() != self.sort_options.len() {
112+
return Err(DataFusionError::Internal(format!(
113+
"Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
114+
self.sort_columns.len(),
115+
self.sort_options.len()
116+
)));
117+
}
118+
119+
let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self
120+
.sort_columns
121+
.iter()
122+
.zip(other.sort_columns.iter())
123+
.zip(self.sort_options.iter())
124+
.collect::<Vec<_>>();
125+
126+
self.init_cmp_if_needed(other, &zipped)?;
127+
let map = self.batch_comparators.read();
128+
let cmp = map.get(&other.batch_id).ok_or_else(|| {
129+
DataFusionError::Execution(format!(
130+
"Failed to find comparator for {} cmp {}",
131+
self.batch_id, other.batch_id
132+
))
133+
})?;
134+
135+
for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
136+
match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
137+
(false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
138+
(false, true) => return Ok(Ordering::Greater),
139+
(true, false) if sort_options.nulls_first => {
140+
return Ok(Ordering::Greater)
141+
}
142+
(true, false) => return Ok(Ordering::Less),
143+
(false, false) => {}
144+
(true, true) => match cmp[i](self.cur_row, other.cur_row) {
145+
Ordering::Equal => {}
146+
o if sort_options.descending => return Ok(o.reverse()),
147+
o => return Ok(o),
148+
},
149+
}
150+
}
151+
152+
// Break ties using stream_idx to ensure a predictable
153+
// ordering of rows when comparing equal streams.
154+
Ok(self.stream_idx.cmp(&other.stream_idx))
155+
}
156+
157+
/// Initialize a collection of comparators for comparing
158+
/// columnar arrays of this cursor and "other" if needed.
159+
fn init_cmp_if_needed(
160+
&self,
161+
other: &SortKeyCursor,
162+
zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
163+
) -> Result<()> {
164+
let hm = self.batch_comparators.read();
165+
if !hm.contains_key(&other.batch_id) {
166+
drop(hm);
167+
let mut map = self.batch_comparators.write();
168+
let cmp = map
169+
.borrow_mut()
170+
.entry(other.batch_id)
171+
.or_insert_with(|| Vec::with_capacity(other.sort_columns.len()));
172+
173+
for (i, ((l, r), _)) in zipped.iter().enumerate() {
174+
if i >= cmp.len() {
175+
// initialise comparators
176+
cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?);
177+
}
178+
}
179+
}
180+
Ok(())
181+
}
182+
}
183+
184+
impl Ord for SortKeyCursor {
185+
/// Needed by min-heap comparison and reverse the order at the same time.
186+
fn cmp(&self, other: &Self) -> Ordering {
187+
other.compare(self).unwrap()
188+
}
189+
}
190+
191+
impl PartialEq for SortKeyCursor {
192+
fn eq(&self, other: &Self) -> bool {
193+
other.compare(self).unwrap() == Ordering::Equal
194+
}
195+
}
196+
197+
impl Eq for SortKeyCursor {}
198+
199+
impl PartialOrd for SortKeyCursor {
200+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
201+
other.compare(self).ok()
202+
}
203+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/// A `RowIndex` identifies a specific row in a logical stream.
2+
///
3+
/// Each stream is identified by an `stream_idx` and is formed from a
4+
/// sequence of RecordBatches batches, each of which is identified by
5+
/// a unique `batch_idx` within that stream.
6+
///
7+
/// This is used by `SortPreservingMergeStream` to identify which
8+
/// the order of the tuples in the final sorted output stream.
9+
///
10+
/// ```text
11+
/// ┌────┐ ┌────┐ ┌────┐ RecordBatch
12+
/// │ │ │ │ │ │
13+
/// │ C1 │ │... │ │ CN │◀─────── (batch_idx = 0)
14+
/// │ │ │ │ │ │
15+
/// └────┘ └────┘ └────┘
16+
///
17+
/// ┌────┐ ┌────┐ ┌────┐ RecordBatch
18+
/// │ │ │ │ │ │
19+
/// │ C1 │ │... │ │ CN │◀─────── (batch_idx = 1)
20+
/// │ │ │ │ │ │
21+
/// └────┘ └────┘ └────┘
22+
///
23+
/// ...
24+
///
25+
/// ┌────┐ ┌────┐ ┌────┐ RecordBatch
26+
/// │ │ │ │ │ │
27+
/// │ C1 │ │... │ │ CN │◀────── (batch_idx = N-1)
28+
/// │ │ │ │ │ │
29+
/// └────┘ └────┘ └────┘
30+
///
31+
/// "Stream"
32+
/// of N RecordBatches
33+
/// ```
34+
#[derive(Debug, Clone)]
35+
pub struct RowIndex {
36+
/// The index of the stream (uniquely identifies the stream)
37+
pub stream_idx: usize,
38+
/// The index of the batch within the stream's VecDequeue.
39+
pub batch_idx: usize,
40+
/// The row index within the batch
41+
pub row_idx: usize,
42+
}

0 commit comments

Comments
 (0)