
Commit 999866e

Return error rather than wrong results when aggregate without retract_batch is used as a sliding accumulator
1 parent e833914 commit 999866e

4 files changed: 104 additions and 74 deletions

datafusion/core/src/physical_plan/udaf.rs

Lines changed: 48 additions & 11 deletions
@@ -41,7 +41,7 @@ pub fn create_aggregate_expr(
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
     input_schema: &Schema,
     name: impl Into<String>,
-) -> Result<Arc<AggregateFunctionExpr>> {
+) -> Result<Arc<dyn AggregateExpr>> {
     let input_exprs_types = input_phy_exprs
         .iter()
         .map(|arg| arg.data_type(input_schema))
@@ -70,11 +70,6 @@ impl AggregateFunctionExpr {
     pub fn fun(&self) -> &AggregateUDF {
         &self.fun
     }
-
-    /// Returns true if this can support sliding accumulators
-    pub fn retractable(&self) -> Result<bool> {
-        Ok((self.fun.accumulator)(&self.data_type)?.supports_retract_batch())
-    }
 }
 
 impl AggregateExpr for AggregateFunctionExpr {
@@ -114,12 +109,54 @@ impl AggregateExpr for AggregateFunctionExpr {
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         let accumulator = (self.fun.accumulator)(&self.data_type)?;
 
+        // Accumulators whose window frame start bound is different
+        // from `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
+        // implement the retract_batch method in order to run correctly
+        // currently in DataFusion.
+        //
+        // If retract_batch is not present, there is no way
+        // to calculate the result correctly. For example, for the query
+        //
+        // ```sql
+        // SELECT
+        //   SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS sum_a
+        // FROM
+        //   t
+        // ```
+        //
+        // 1. The first sum value will be the sum of rows between `[0, 1)`,
+        //
+        // 2. The second sum value will be the sum of rows between `[0, 2)`,
+        //
+        // 3. The third sum value will be the sum of rows between `[1, 3)`, etc.
+        //
+        // Since the accumulator keeps the running sum:
+        //
+        // 1. For the first sum we add to the state the values between `[0, 1)`,
+        //
+        // 2. For the second sum we add to the state the values between `[1, 2)`
+        // (`[0, 1)` is already in the state sum, hence the running sum will
+        // cover the `[0, 2)` range),
+        //
+        // 3. For the third sum we add to the state the values between `[2, 3)`
+        // (`[0, 2)` is already in the state sum). We also need to
+        // retract the values between `[0, 1)`; this way we obtain the sum
+        // between `[1, 3)`, which is indeed the appropriate range.
+        //
+        // When we use `UNBOUNDED PRECEDING` in the query, the starting
+        // index will always be 0 for the desired range, and hence the
+        // `retract_batch` method will not be called. In this case
+        // having retract_batch is not a requirement.
+        //
+        // This approach is a bit different from the window function
+        // approach. Window functions (when they use a window frame)
+        // get the entire desired range during evaluation.
         if !accumulator.supports_retract_batch() {
-            return Err(DataFusionError::Internal(
-                format!(
-                "Can't make sliding accumulator because retractable_accumulator not available for {}",
-                self.name)
-            ));
+            return Err(DataFusionError::NotImplemented(format!(
+                "Aggregate can not be used as a sliding accumulator because \
+                 `retract_batch` is not implemented: {}",
+                self.name
+            )));
         }
         Ok(accumulator)
     }
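
To make the retract mechanism described in the comment above concrete, here is a minimal, self-contained sketch of a sliding sum over the frame `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING` (the frame used by the tests below). `SlidingSum`, `update`, and `retract` are illustrative stand-ins for `Accumulator::update_batch` / `retract_batch`, not DataFusion APIs:

```rust
/// Illustrative sliding sum; `update` / `retract` stand in for
/// `update_batch` / `retract_batch` and are not DataFusion APIs.
struct SlidingSum {
    sum: i64,
}

impl SlidingSum {
    fn update(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    fn retract(&mut self, values: &[i64]) {
        self.sum -= values.iter().sum::<i64>();
    }
}

fn main() {
    let data = [1_i64, 2, 3, 4, 5];
    let mut acc = SlidingSum { sum: 0 };
    // Range of rows already folded into the running sum.
    let mut prev: Option<(usize, usize)> = None;

    for i in 0..data.len() {
        // Frame for row i: `1 PRECEDING AND 1 FOLLOWING`, clamped to the data.
        let start = i.saturating_sub(1);
        let end = (i + 2).min(data.len());

        match prev {
            Some((prev_start, prev_end)) => {
                // Add only the rows that newly entered the frame ...
                acc.update(&data[prev_end..end]);
                // ... and retract the rows that left it. Without `retract`,
                // stale rows stay in the running sum, which is exactly the
                // wrong-result case this commit turns into an error.
                acc.retract(&data[prev_start..start]);
            }
            None => acc.update(&data[start..end]),
        }
        prev = Some((start, end));
        println!("row {i}: frame [{start}, {end}) -> sum {}", acc.sum);
    }
}
```

For this input the sliding sums are 3, 6, 9, 12, 9; dropping the `retract` call yields 3, 6, 10, 15, 15 instead, which is the kind of silent error the new `NotImplemented` check prevents.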

datafusion/core/src/physical_plan/windows/mod.rs

Lines changed: 42 additions & 36 deletions
@@ -33,8 +33,9 @@ use datafusion_expr::{
     window_function::{BuiltInWindowFunction, WindowFunction},
     WindowFrame,
 };
-use datafusion_physical_expr::window::{
-    BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr,
+use datafusion_physical_expr::{
+    window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr},
+    AggregateExpr,
 };
 use std::borrow::Borrow;
 use std::convert::TryInto;
@@ -64,28 +65,16 @@ pub fn create_window_expr(
     window_frame: Arc<WindowFrame>,
     input_schema: &Schema,
 ) -> Result<Arc<dyn WindowExpr>> {
-    // Is there a potentially unlimited sized window frame?
-    let unbounded_window = window_frame.start_bound.is_unbounded();
-
     Ok(match fun {
         WindowFunction::AggregateFunction(fun) => {
            let aggregate =
                 aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?;
-            if !unbounded_window {
-                Arc::new(SlidingAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            } else {
-                Arc::new(PlainAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            }
+            window_expr_from_aggregate_expr(
+                partition_by,
+                order_by,
+                window_frame,
+                aggregate,
+            )
         }
         WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new(
             create_built_in_window_expr(fun, args, input_schema, name)?,
@@ -96,26 +85,43 @@ pub fn create_window_expr(
         WindowFunction::AggregateUDF(fun) => {
             let aggregate =
                 udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?;
-
-            if !unbounded_window && aggregate.retractable()? {
-                Arc::new(SlidingAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            } else {
-                Arc::new(PlainAggregateWindowExpr::new(
-                    aggregate,
-                    partition_by,
-                    order_by,
-                    window_frame,
-                ))
-            }
+            window_expr_from_aggregate_expr(
+                partition_by,
+                order_by,
+                window_frame,
+                aggregate,
+            )
         }
     })
 }
 
+/// Creates an appropriate [`WindowExpr`] based on the window frame and the given aggregate expression.
+fn window_expr_from_aggregate_expr(
+    partition_by: &[Arc<dyn PhysicalExpr>],
+    order_by: &[PhysicalSortExpr],
+    window_frame: Arc<WindowFrame>,
+    aggregate: Arc<dyn AggregateExpr>,
+) -> Arc<dyn WindowExpr> {
+    // Is there a potentially unlimited sized window frame?
+    let unbounded_window = window_frame.start_bound.is_unbounded();
+
+    if !unbounded_window {
+        Arc::new(SlidingAggregateWindowExpr::new(
+            aggregate,
+            partition_by,
+            order_by,
+            window_frame,
+        ))
+    } else {
+        Arc::new(PlainAggregateWindowExpr::new(
+            aggregate,
+            partition_by,
+            order_by,
+            window_frame,
+        ))
+    }
+}
+
 fn get_scalar_value_from_args(
     args: &[Arc<dyn PhysicalExpr>],
     index: usize,
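
For reference, the branch in the new `window_expr_from_aggregate_expr` helper boils down to: a frame whose start bound is `UNBOUNDED PRECEDING` only ever grows, so the plain path suffices; any other start bound makes the frame slide and requires `retract_batch`. A minimal sketch with hypothetical stand-in types (`FrameStart`, `WindowExprKind`, and `pick_window_expr` are not DataFusion APIs):

```rust
// Hypothetical stand-ins that only mirror the branch in
// `window_expr_from_aggregate_expr`; the real DataFusion types differ.
#[derive(Debug)]
enum FrameStart {
    UnboundedPreceding, // e.g. `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
    Preceding(u64),     // e.g. `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`
}

#[derive(Debug)]
enum WindowExprKind {
    Plain,   // PlainAggregateWindowExpr: frame only grows, no retraction needed
    Sliding, // SlidingAggregateWindowExpr: rows leave the frame, needs retract_batch
}

fn pick_window_expr(start: &FrameStart) -> WindowExprKind {
    match start {
        FrameStart::UnboundedPreceding => WindowExprKind::Plain,
        FrameStart::Preceding(_) => WindowExprKind::Sliding,
    }
}

fn main() {
    println!("{:?}", pick_window_expr(&FrameStart::UnboundedPreceding)); // Plain
    println!("{:?}", pick_window_expr(&FrameStart::Preceding(1)));       // Sliding
}
```

Note that after this change the choice no longer consults `retractable()`; a UDAF that lands on the sliding path without `retract_batch` now fails in `create_sliding_accumulator` instead of silently producing wrong results.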

datafusion/core/tests/user_defined_aggregates.rs

Lines changed: 14 additions & 26 deletions
@@ -40,7 +40,7 @@ use datafusion::{
     prelude::SessionContext,
     scalar::ScalarValue,
 };
-use datafusion_common::{cast::as_primitive_array, DataFusionError};
+use datafusion_common::{assert_contains, cast::as_primitive_array, DataFusionError};
 
 /// Test to show the contents of the setup
 #[tokio::test]
@@ -58,7 +58,7 @@ async fn test_setup() {
         "| 5.0   | 1970-01-01T00:00:00.000005 |",
         "+-------+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
 }
 
 /// Basic user defined aggregate
@@ -74,7 +74,7 @@ async fn test_udaf() {
         "| 1970-01-01T00:00:00.000019 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // normal aggregates call update_batch
     assert!(test_state.update_batch());
     assert!(!test_state.retract_batch());
@@ -96,7 +96,7 @@ async fn test_udaf_as_window() {
         "| 1970-01-01T00:00:00.000019 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // aggregate over the entire window function call update_batch
     assert!(test_state.update_batch());
     assert!(!test_state.retract_batch());
@@ -118,35 +118,23 @@ async fn test_udaf_as_window_with_frame() {
         "| 1970-01-01T00:00:00.000010 |",
         "+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
     // user defined aggregates with window frame should be calling retract batch
     assert!(test_state.update_batch());
     assert!(test_state.retract_batch());
 }
 
 /// Ensure that User defined aggregate used as a window function with a window
-/// frame, but that does not implement retract_batch, does not error
+/// frame, but that does not implement retract_batch, returns an error
 #[tokio::test]
 async fn test_udaf_as_window_with_frame_without_retract_batch() {
     let test_state = Arc::new(TestState::new().with_error_on_retract_batch());
 
-    let TestContext { ctx, test_state } = TestContext::new_with_test_state(test_state);
+    let TestContext { ctx, test_state: _ } = TestContext::new_with_test_state(test_state);
     let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
-    // TODO: It is not clear why this is a different value than when retract batch is used
-    let expected = vec![
-        "+----------------------------+",
-        "| time_sum                   |",
-        "+----------------------------+",
-        "| 1970-01-01T00:00:00.000005 |",
-        "| 1970-01-01T00:00:00.000009 |",
-        "| 1970-01-01T00:00:00.000014 |",
-        "| 1970-01-01T00:00:00.000019 |",
-        "| 1970-01-01T00:00:00.000019 |",
-        "+----------------------------+",
-    ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
-    assert!(test_state.update_batch());
-    assert!(!test_state.retract_batch());
+    // Note: if this query ever does start working, update this test to check the results
+    let err = execute(&ctx, sql).await.unwrap_err();
+    assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\"");
 }
 
 /// Basic query for with a udaf returning a structure
@@ -161,7 +149,7 @@ async fn test_udaf_returning_struct() {
         "| {value: 2.0, time: 1970-01-01T00:00:00.000002} |",
         "+------------------------------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
 }
 
 /// Demonstrate extracting the fields from a structure using a subquery
@@ -176,11 +164,11 @@ async fn test_udaf_returning_struct_subquery() {
         "| 2.0             | 1970-01-01T00:00:00.000002 |",
         "+-----------------+----------------------------+",
     ];
-    assert_batches_eq!(expected, &execute(&ctx, sql).await);
+    assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());
 }
 
-async fn execute(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
-    ctx.sql(sql).await.unwrap().collect().await.unwrap()
+async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+    ctx.sql(sql).await?.collect().await
 }
 
 /// Returns an context with a table "t" and the "first" and "time_sum"
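
Callers that previously got wrong results from such a query will now see the error asserted above. A hedged sketch of how downstream code might surface it (`explain_window_error` and the `main` below are illustrative, not part of this test suite; they only assume the `DataFusionError::NotImplemented` variant used in the udaf.rs change):

```rust
use datafusion_common::DataFusionError;

/// Illustrative helper: turn the new `NotImplemented` error into a hint.
fn explain_window_error(err: &DataFusionError) -> String {
    match err {
        DataFusionError::NotImplemented(msg) if msg.contains("retract_batch") => format!(
            "UDAF cannot run over a bounded window frame: {msg}; \
             implement Accumulator::retract_batch (and return true from \
             supports_retract_batch) or use an UNBOUNDED PRECEDING frame"
        ),
        other => other.to_string(),
    }
}

fn main() {
    // Construct an error of the same shape the new check produces.
    let err = DataFusionError::NotImplemented(
        "Aggregate can not be used as a sliding accumulator because \
         `retract_batch` is not implemented: time_sum"
            .to_string(),
    );
    println!("{}", explain_window_error(&err));
}
```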

datafusion/proto/src/physical_plan/mod.rs

Lines changed: 0 additions & 1 deletion
@@ -465,7 +465,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
                     let agg_udf = registry.udaf(udaf_name)?;
                     udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name)
-                        .map(|func| func as Arc<dyn AggregateExpr>)
                 }
             }
         }).transpose()?.ok_or_else(|| {
