Skip to content

Commit afc9c9d

Browse files
[MINOR] Remove update state api from PartitionEvaluator (#6966)
* remove update_state api from partition_evaluator * Resolve linter errors * Simplifications * remove row_idx argument from evaluate * Simplifications * Update datafusion/expr/src/partition_evaluator.rs Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com> * Update comment Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com> * Update document * Use boolean operator instead of bitwise --------- Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
1 parent d316702 commit afc9c9d

File tree

8 files changed

+67
-163
lines changed

8 files changed

+67
-163
lines changed

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ fn get_random_function(
208208
vec![],
209209
),
210210
);
211+
window_fn_map.insert(
212+
"dense_rank",
213+
(
214+
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank),
215+
vec![],
216+
),
217+
);
211218
window_fn_map.insert(
212219
"lead",
213220
(

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use arrow_schema::DataType;
3232
use datafusion::{assert_batches_eq, prelude::SessionContext};
3333
use datafusion_common::{Result, ScalarValue};
3434
use datafusion_expr::{
35-
function::PartitionEvaluatorFactory, window_state::WindowAggState,
36-
PartitionEvaluator, ReturnTypeFunction, Signature, Volatility, WindowUDF,
35+
function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction,
36+
Signature, Volatility, WindowUDF,
3737
};
3838

3939
/// A query with a window function evaluated over the entire partition
@@ -195,7 +195,6 @@ async fn test_stateful_udwf() {
195195
&execute(&ctx, UNBOUNDED_WINDOW_QUERY).await.unwrap()
196196
);
197197
assert_eq!(test_state.evaluate_called(), 10);
198-
assert_eq!(test_state.update_state_called(), 10);
199198
assert_eq!(test_state.evaluate_all_called(), 0);
200199
}
201200

@@ -229,7 +228,6 @@ async fn test_stateful_udwf_bounded_window() {
229228
);
230229
// Evaluate and update_state is called for each input row
231230
assert_eq!(test_state.evaluate_called(), 10);
232-
assert_eq!(test_state.update_state_called(), 10);
233231
assert_eq!(test_state.evaluate_all_called(), 0);
234232
}
235233

@@ -388,8 +386,6 @@ struct TestState {
388386
evaluate_all_called: AtomicUsize,
389387
/// How many times was `evaluate` called?
390388
evaluate_called: AtomicUsize,
391-
/// How many times was `update_state` called?
392-
update_state_called: AtomicUsize,
393389
/// How many times was `evaluate_all_with_rank` called?
394390
evaluate_all_with_rank_called: AtomicUsize,
395391
/// should the functions say they use the window frame?
@@ -451,16 +447,6 @@ impl TestState {
451447
self.evaluate_called.fetch_add(1, Ordering::SeqCst);
452448
}
453449

454-
/// return the update_state_called counter
455-
fn update_state_called(&self) -> usize {
456-
self.update_state_called.load(Ordering::SeqCst)
457-
}
458-
459-
/// update the update_state_called counter
460-
fn inc_update_state_called(&self) {
461-
self.update_state_called.fetch_add(1, Ordering::SeqCst);
462-
}
463-
464450
/// return the evaluate_all_with_rank_called counter
465451
fn evaluate_all_with_rank_called(&self) -> usize {
466452
self.evaluate_all_with_rank_called.load(Ordering::SeqCst)
@@ -555,17 +541,6 @@ impl PartitionEvaluator for OddCounter {
555541
Ok(Arc::new(array))
556542
}
557543

558-
fn update_state(
559-
&mut self,
560-
_state: &WindowAggState,
561-
_idx: usize,
562-
_range_columns: &[ArrayRef],
563-
_sort_partition_points: &[Range<usize>],
564-
) -> Result<()> {
565-
self.test_state.inc_update_state_called();
566-
Ok(())
567-
}
568-
569544
fn supports_bounded_execution(&self) -> bool {
570545
self.test_state.supports_bounded_execution
571546
}

datafusion/expr/src/partition_evaluator.rs

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -69,27 +69,10 @@ use crate::window_state::WindowAggState;
6969
/// capabilities described by [`supports_bounded_execution`],
7070
/// [`uses_window_frame`], and [`include_rank`],
7171
///
72-
/// # Stateless `PartitionEvaluator`s
73-
///
74-
/// In this case, `PartitionEvaluator` holds no state, and either
75-
/// [`evaluate_all`] or [`evaluate_all_with_rank`] is called with
76-
/// values for the entire partition.
77-
///
78-
/// # Stateful `PartitionEvaluator`s
79-
///
80-
/// In this case, [`Self::evaluate`] is called to calculate the window
81-
/// function incrementally for each new batch.
82-
///
83-
/// For example, when computing `ROW_NUMBER` incrementally,
84-
/// [`Self::evaluate`] will be called multiple times with
85-
/// different batches. For all batches after the first, the output
86-
/// `row_number` must start from last `row_number` produced for the
87-
/// previous batch. The previous row number is saved and restored as
88-
/// the state.
89-
///
9072
/// When implementing a new `PartitionEvaluator`, implement
9173
/// corresponding evaluator according to table below.
9274
///
75+
/// # Implementation Table
9376
///
9477
/// |[`uses_window_frame`]|[`supports_bounded_execution`]|[`include_rank`]|function_to_implement|
9578
/// |---|---|----|----|
@@ -105,25 +88,6 @@ use crate::window_state::WindowAggState;
10588
/// [`include_rank`]: Self::include_rank
10689
/// [`supports_bounded_execution`]: Self::supports_bounded_execution
10790
pub trait PartitionEvaluator: Debug + Send {
108-
/// Updates the internal state for window function
109-
///
110-
/// Only used for stateful evaluation
111-
///
112-
/// `state`: is useful to update internal state for window function.
113-
/// `idx`: is the index of last row for which result is calculated.
114-
/// `range_columns`: is the result of order by column values. It is used to calculate rank boundaries
115-
/// `sort_partition_points`: is the boundaries of each rank in the range_column. It is used to update rank.
116-
fn update_state(
117-
&mut self,
118-
_state: &WindowAggState,
119-
_idx: usize,
120-
_range_columns: &[ArrayRef],
121-
_sort_partition_points: &[Range<usize>],
122-
) -> Result<()> {
123-
// If we do not use state, update_state does nothing
124-
Ok(())
125-
}
126-
12791
/// When the window frame has a fixed beginning (e.g UNBOUNDED
12892
/// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and
12993
/// NTH_VALUE do not need the (unbounded) input once they have
@@ -220,7 +184,10 @@ pub trait PartitionEvaluator: Debug + Send {
220184
/// trait.
221185
///
222186
/// Returns a [`ScalarValue`] that is the value of the window
223-
/// function within `range` for the entire partition
187+
/// function within `range` for the entire partition. Argument
188+
/// `values` contains the evaluation result of function arguments
189+
/// and evaluation results of ORDER BY expressions. If function has a
190+
/// single argument, `values[1..]` will contain ORDER BY expression results.
224191
fn evaluate(
225192
&mut self,
226193
_values: &[ArrayRef],

datafusion/physical-expr/src/window/built_in.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use std::sync::Arc;
2323

2424
use super::BuiltInWindowFunctionExpr;
2525
use super::WindowExpr;
26-
use crate::window::window_expr::WindowFn;
26+
use crate::window::window_expr::{get_orderby_values, WindowFn};
2727
use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState};
2828
use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
2929
use arrow::array::{new_empty_array, ArrayRef};
@@ -101,14 +101,19 @@ impl WindowExpr for BuiltInWindowExpr {
101101
self.order_by.iter().map(|o| o.options).collect();
102102
let mut row_wise_results = vec![];
103103

104-
let (values, order_bys) = self.get_values_orderbys(batch)?;
104+
let mut values = self.evaluate_args(batch)?;
105+
let order_bys = get_orderby_values(self.order_by_columns(batch)?);
106+
let n_args = values.len();
107+
values.extend(order_bys);
108+
let order_bys_ref = &values[n_args..];
109+
105110
let mut window_frame_ctx =
106111
WindowFrameContext::new(self.window_frame.clone(), sort_options);
107112
let mut last_range = Range { start: 0, end: 0 };
108113
// We iterate on each row to calculate window frame range and and window function result
109114
for idx in 0..num_rows {
110115
let range = window_frame_ctx.calculate_range(
111-
&order_bys,
116+
order_bys_ref,
112117
&last_range,
113118
num_rows,
114119
idx,
@@ -119,11 +124,11 @@ impl WindowExpr for BuiltInWindowExpr {
119124
}
120125
ScalarValue::iter_to_array(row_wise_results.into_iter())
121126
} else if evaluator.include_rank() {
122-
let columns = self.sort_columns(batch)?;
127+
let columns = self.order_by_columns(batch)?;
123128
let sort_partition_points = evaluate_partition_ranges(num_rows, &columns)?;
124129
evaluator.evaluate_all_with_rank(num_rows, &sort_partition_points)
125130
} else {
126-
let (values, _) = self.get_values_orderbys(batch)?;
131+
let values = self.evaluate_args(batch)?;
127132
evaluator.evaluate_all(&values, num_rows)
128133
}
129134
}
@@ -157,18 +162,20 @@ impl WindowExpr for BuiltInWindowExpr {
157162
};
158163
let state = &mut window_state.state;
159164

160-
let (values, order_bys) =
161-
self.get_values_orderbys(&partition_batch_state.record_batch)?;
165+
let batch_ref = &partition_batch_state.record_batch;
166+
let mut values = self.evaluate_args(batch_ref)?;
167+
let order_bys = if evaluator.uses_window_frame() || evaluator.include_rank() {
168+
get_orderby_values(self.order_by_columns(batch_ref)?)
169+
} else {
170+
vec![]
171+
};
172+
let n_args = values.len();
173+
values.extend(order_bys);
174+
let order_bys_ref = &values[n_args..];
162175

163176
// We iterate on each row to perform a running calculation.
164177
let record_batch = &partition_batch_state.record_batch;
165178
let num_rows = record_batch.num_rows();
166-
let sort_partition_points = if evaluator.include_rank() {
167-
let columns = self.sort_columns(record_batch)?;
168-
evaluate_partition_ranges(num_rows, &columns)?
169-
} else {
170-
vec![]
171-
};
172179
let mut row_wise_results: Vec<ScalarValue> = vec![];
173180
for idx in state.last_calculated_index..num_rows {
174181
let frame_range = if evaluator.uses_window_frame() {
@@ -181,7 +188,7 @@ impl WindowExpr for BuiltInWindowExpr {
181188
)
182189
})
183190
.calculate_range(
184-
&order_bys,
191+
order_bys_ref,
185192
// Start search from the last range
186193
&state.window_frame_range,
187194
num_rows,
@@ -197,7 +204,6 @@ impl WindowExpr for BuiltInWindowExpr {
197204
}
198205
// Update last range
199206
state.window_frame_range = frame_range;
200-
evaluator.update_state(state, idx, &order_bys, &sort_partition_points)?;
201207
row_wise_results
202208
.push(evaluator.evaluate(&values, &state.window_frame_range)?);
203209
}

datafusion/physical-expr/src/window/lead_lag.rs

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,13 @@
1818
//! Defines physical expression for `lead` and `lag` that can evaluated
1919
//! at runtime during query execution
2020
21-
use crate::window::window_expr::LeadLagState;
2221
use crate::window::BuiltInWindowFunctionExpr;
2322
use crate::PhysicalExpr;
2423
use arrow::array::ArrayRef;
2524
use arrow::compute::cast;
2625
use arrow::datatypes::{DataType, Field};
2726
use datafusion_common::ScalarValue;
2827
use datafusion_common::{DataFusionError, Result};
29-
use datafusion_expr::window_state::WindowAggState;
3028
use datafusion_expr::PartitionEvaluator;
3129
use std::any::Any;
3230
use std::cmp::min;
@@ -105,7 +103,6 @@ impl BuiltInWindowFunctionExpr for WindowShift {
105103

106104
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
107105
Ok(Box::new(WindowShiftEvaluator {
108-
state: LeadLagState { idx: 0 },
109106
shift_offset: self.shift_offset,
110107
default_value: self.default_value.clone(),
111108
}))
@@ -124,7 +121,6 @@ impl BuiltInWindowFunctionExpr for WindowShift {
124121

125122
#[derive(Debug)]
126123
pub(crate) struct WindowShiftEvaluator {
127-
state: LeadLagState,
128124
shift_offset: i64,
129125
default_value: Option<ScalarValue>,
130126
}
@@ -179,17 +175,6 @@ fn shift_with_default_value(
179175
}
180176

181177
impl PartitionEvaluator for WindowShiftEvaluator {
182-
fn update_state(
183-
&mut self,
184-
_state: &WindowAggState,
185-
idx: usize,
186-
_range_columns: &[ArrayRef],
187-
_sort_partition_points: &[Range<usize>],
188-
) -> Result<()> {
189-
self.state.idx = idx;
190-
Ok(())
191-
}
192-
193178
fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
194179
if self.shift_offset > 0 {
195180
let offset = self.shift_offset as usize;
@@ -206,11 +191,18 @@ impl PartitionEvaluator for WindowShiftEvaluator {
206191
fn evaluate(
207192
&mut self,
208193
values: &[ArrayRef],
209-
_range: &Range<usize>,
194+
range: &Range<usize>,
210195
) -> Result<ScalarValue> {
211196
let array = &values[0];
212197
let dtype = array.data_type();
213-
let idx = self.state.idx as i64 - self.shift_offset;
198+
// LAG mode
199+
let idx = if self.shift_offset > 0 {
200+
range.end as i64 - self.shift_offset - 1
201+
} else {
202+
// LEAD mode
203+
range.start as i64 - self.shift_offset
204+
};
205+
214206
if idx < 0 || idx as usize >= array.len() {
215207
get_default_value(self.default_value.as_ref(), dtype)
216208
} else {

datafusion/physical-expr/src/window/nth_value.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,6 @@ pub(crate) struct NthValueEvaluator {
145145
}
146146

147147
impl PartitionEvaluator for NthValueEvaluator {
148-
fn update_state(
149-
&mut self,
150-
state: &WindowAggState,
151-
_idx: usize,
152-
_range_columns: &[ArrayRef],
153-
_sort_partition_points: &[Range<usize>],
154-
) -> Result<()> {
155-
// If we do not use state, update_state does nothing
156-
self.state.range.clone_from(&state.window_frame_range);
157-
Ok(())
158-
}
159-
160148
/// When the window frame has a fixed beginning (e.g UNBOUNDED
161149
/// PRECEDING), for some functions such as FIRST_VALUE, LAST_VALUE and
162150
/// NTH_VALUE we can memoize result. Once result is calculated it

datafusion/physical-expr/src/window/rank.rs

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ use arrow::array::{Float64Array, UInt64Array};
2626
use arrow::datatypes::{DataType, Field};
2727
use datafusion_common::utils::get_row_at_idx;
2828
use datafusion_common::{DataFusionError, Result, ScalarValue};
29-
use datafusion_expr::window_state::WindowAggState;
3029
use datafusion_expr::PartitionEvaluator;
3130
use std::any::Any;
3231
use std::iter;
@@ -116,39 +115,26 @@ pub(crate) struct RankEvaluator {
116115
}
117116

118117
impl PartitionEvaluator for RankEvaluator {
119-
fn update_state(
118+
/// Evaluates the window function inside the given range.
119+
fn evaluate(
120120
&mut self,
121-
state: &WindowAggState,
122-
idx: usize,
123-
range_columns: &[ArrayRef],
124-
sort_partition_points: &[Range<usize>],
125-
) -> Result<()> {
126-
// find range inside `sort_partition_points` containing `idx`
127-
let chunk_idx = sort_partition_points
128-
.iter()
129-
.position(|elem| elem.start <= idx && idx < elem.end)
130-
.ok_or_else(|| {
131-
DataFusionError::Execution(
132-
"Expects sort_partition_points to contain idx".to_string(),
133-
)
134-
})?;
135-
let chunk = &sort_partition_points[chunk_idx];
136-
let last_rank_data = get_row_at_idx(range_columns, chunk.end - 1)?;
121+
values: &[ArrayRef],
122+
range: &Range<usize>,
123+
) -> Result<ScalarValue> {
124+
let row_idx = range.start;
125+
// There is no argument, values are order by column values (where rank is calculated)
126+
let range_columns = values;
127+
let last_rank_data = get_row_at_idx(range_columns, row_idx)?;
137128
let empty = self.state.last_rank_data.is_empty();
138129
if empty || self.state.last_rank_data != last_rank_data {
139130
self.state.last_rank_data = last_rank_data;
140-
self.state.last_rank_boundary = state.offset_pruned_rows + chunk.start;
141-
self.state.n_rank = 1 + if empty { chunk_idx } else { self.state.n_rank };
131+
self.state.last_rank_boundary += self.state.current_group_count;
132+
self.state.current_group_count = 1;
133+
self.state.n_rank += 1;
134+
} else {
135+
// data is still in the same rank
136+
self.state.current_group_count += 1;
142137
}
143-
Ok(())
144-
}
145-
146-
/// evaluate window function result inside given range
147-
fn evaluate(
148-
&mut self,
149-
_values: &[ArrayRef],
150-
_range: &Range<usize>,
151-
) -> Result<ScalarValue> {
152138
match self.rank_type {
153139
RankType::Basic => Ok(ScalarValue::UInt64(Some(
154140
self.state.last_rank_boundary as u64 + 1,

0 commit comments

Comments
 (0)