Skip to content

Commit 90daf11

Browse files
author
Jiayu Liu
committed
Squashed commit of the following:
commit 7fb3640 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 16:38:25 2021 +0800 row number done commit 1723926 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 16:05:50 2021 +0800 add row number commit bf5b8a5 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 15:04:49 2021 +0800 save commit d2ce852 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 14:53:05 2021 +0800 add streams commit 0a861a7 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 22:28:34 2021 +0800 save stream commit a9121af Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 22:01:51 2021 +0800 update unit test commit 2af2a27 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 14:25:12 2021 +0800 fix unit test commit bb57c76 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 14:23:34 2021 +0800 use upper case commit 5d96e52 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 14:16:16 2021 +0800 fix unit test commit 1ecae8f Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 12:27:26 2021 +0800 fix unit test commit bc2271d Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 10:04:29 2021 +0800 fix error commit 880b94f Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 08:24:00 2021 +0800 fix unit test commit 4e792e1 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 08:05:17 2021 +0800 fix test commit c36c04a Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Fri May 21 00:07:54 2021 +0800 add more tests commit f5e64de Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 23:41:36 2021 +0800 update commit a1eae86 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 23:36:15 2021 +0800 enrich unit test commit 0d2a214 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 23:25:43 2021 +0800 adding filter by todo commit 8b486d5 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 20 23:17:22 2021 +0800 adding more built-in functions commit abf08cd Author: Jiayu Liu 
<Jimexist@users.noreply.github.com> Date: Thu May 20 22:36:27 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> commit 0cbca53 Author: Jiayu Liu <Jimexist@users.noreply.github.com> Date: Thu May 20 22:34:57 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> commit 831c069 Author: Jiayu Liu <Jimexist@users.noreply.github.com> Date: Thu May 20 22:34:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> commit f70c739 Author: Jiayu Liu <Jimexist@users.noreply.github.com> Date: Thu May 20 22:33:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> commit 3ee87aa Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Wed May 19 22:55:08 2021 +0800 fix unit test commit 5c4d92d Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Wed May 19 22:48:26 2021 +0800 fix clippy commit a0b7526 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Wed May 19 22:46:38 2021 +0800 fix unused imports commit 1d3b076 Author: Jiayu Liu <jiayu.liu@airbnb.com> Date: Thu May 13 18:51:14 2021 +0800 add window expr
1 parent db4f098 commit 90daf11

File tree

10 files changed

+527
-51
lines changed

10 files changed

+527
-51
lines changed

datafusion/src/execution/context.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,65 @@ mod tests {
12681268
Ok(())
12691269
}
12701270

1271+
#[tokio::test]
1272+
async fn window() -> Result<()> {
1273+
let results = execute("SELECT c1, MAX(c2) OVER () FROM test", 4).await?;
1274+
assert_eq!(results.len(), 1);
1275+
Ok(())
1276+
}
1277+
1278+
#[tokio::test]
1279+
async fn window_plan() -> Result<()> {
1280+
let schema = Schema::new(vec![
1281+
Field::new("a", DataType::Int32, false),
1282+
Field::new("b", DataType::Int32, false),
1283+
]);
1284+
1285+
let batch = RecordBatch::try_new(
1286+
Arc::new(schema.clone()),
1287+
vec![
1288+
Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
1289+
Arc::new(Int32Array::from(vec![2, 12, 12, 120])),
1290+
],
1291+
)?;
1292+
let mut ctx = ExecutionContext::new();
1293+
let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?;
1294+
ctx.register_table("t", Arc::new(provider))?;
1295+
1296+
let logical_plan = ctx.create_logical_plan("SELECT a, MAX(b) OVER () FROM t")?;
1297+
let opt_plan = ctx.optimize(&logical_plan)?;
1298+
let physical_plan = ctx.create_physical_plan(&opt_plan)?;
1299+
assert_eq!(
1300+
format!("{:?}", physical_plan),
1301+
"ProjectionExec { \
1302+
expr: [(Column { name: \"a\" }, \"a\"), (Column { name: \"MAX(b)\" }, \"MAX(b)\")], \
1303+
schema: Schema { \
1304+
fields: [\
1305+
Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
1306+
Field { name: \"MAX(b)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }], \
1307+
metadata: {} }, \
1308+
input: RepartitionExec { \
1309+
input: WindowAggExec { \
1310+
input: partitions: [...]\
1311+
schema: Schema { fields: [\
1312+
Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
1313+
Field { name: \"b\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }], \
1314+
metadata: {} }\
1315+
projection: Some([0, 1]), \
1316+
window_expr: [AggregateWindowExpr { \
1317+
aggregate: Max { name: \"MAX(b)\", data_type: Int32, nullable: true, expr: Column { name: \"b\" } } }], \
1318+
schema: Schema { fields: [\
1319+
Field { name: \"MAX(b)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
1320+
Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
1321+
Field { name: \"b\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }], metadata: {} }, \
1322+
input_schema: Schema { fields: [\
1323+
Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
1324+
Field { name: \"b\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }], metadata: {} } }, \
1325+
partitioning: RoundRobinBatch(16), channels: Mutex { data: {} } } }"
1326+
);
1327+
Ok(())
1328+
}
1329+
12711330
#[tokio::test]
12721331
async fn aggregate() -> Result<()> {
12731332
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;

datafusion/src/physical_plan/expressions/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ mod min_max;
4141
mod negative;
4242
mod not;
4343
mod nullif;
44+
mod row_number;
4445
mod sum;
4546
mod try_cast;
4647

@@ -58,6 +59,7 @@ pub use min_max::{Max, Min};
5859
pub use negative::{negative, NegativeExpr};
5960
pub use not::{not, NotExpr};
6061
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
62+
pub use row_number::RowNumber;
6163
pub use sum::{sum_return_type, Sum};
6264
pub use try_cast::{try_cast, TryCastExpr};
6365
/// returns the name of the state
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Defines physical expressions that can be evaluated at runtime during query execution
19+
20+
use std::any::Any;
21+
use std::convert::TryFrom;
22+
use std::sync::Arc;
23+
24+
use crate::error::{DataFusionError, Result};
25+
use crate::physical_plan::{
26+
Accumulator, AggregateExpr, BuiltInWindowFunctionExpr, PhysicalExpr,
27+
};
28+
use crate::scalar::ScalarValue;
29+
use arrow::compute;
30+
use arrow::datatypes::{DataType, TimeUnit};
31+
use arrow::{
32+
array::{
33+
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
34+
Int8Array, LargeStringArray, StringArray, TimestampMicrosecondArray,
35+
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
36+
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
37+
},
38+
datatypes::Field,
39+
};
40+
41+
/// row_number expression
42+
#[derive(Debug)]
43+
pub struct RowNumber {
44+
name: String,
45+
}
46+
47+
impl RowNumber {
48+
/// Create a new ROW_NUMBER window function
49+
pub fn new(name: String) -> Self {
50+
Self { name }
51+
}
52+
}
53+
54+
impl BuiltInWindowFunctionExpr for RowNumber {
55+
/// Return a reference to Any that can be used for downcasting
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
60+
fn field(&self) -> Result<Field> {
61+
let nullable = false;
62+
let data_type = DataType::UInt64;
63+
Ok(Field::new(&self.name, data_type, nullable))
64+
}
65+
66+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
67+
vec![]
68+
}
69+
70+
fn name(&self) -> &str {
71+
&self.name
72+
}
73+
}

datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ impl GroupedHashAggregateStream {
712712
tx.send(result)
713713
});
714714

715-
GroupedHashAggregateStream {
715+
Self {
716716
schema,
717717
output: rx,
718718
finished: false,
@@ -825,7 +825,8 @@ fn aggregate_expressions(
825825
}
826826

827827
pin_project! {
828-
struct HashAggregateStream {
828+
/// stream struct for hash aggregation
829+
pub struct HashAggregateStream {
829830
schema: SchemaRef,
830831
#[pin]
831832
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
@@ -878,7 +879,7 @@ impl HashAggregateStream {
878879
tx.send(result)
879880
});
880881

881-
HashAggregateStream {
882+
Self {
882883
schema,
883884
output: rx,
884885
finished: false,

datafusion/src/physical_plan/mod.rs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,41 @@ pub trait WindowExpr: Send + Sync + Debug {
457457
fn name(&self) -> &str {
458458
"WindowExpr: default name"
459459
}
460+
461+
/// the accumulator used to accumulate values from the expressions.
462+
/// the accumulator expects the same number of arguments as `expressions` and must
463+
/// return states with the same description as `state_fields`
464+
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
465+
466+
/// expressions that are passed to the WindowAccumulator.
467+
/// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many.
468+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
469+
}
470+
471+
/// A window expression that is a built-in window function
472+
pub trait BuiltInWindowFunctionExpr: Send + Sync + Debug {
473+
/// Returns the window function expression as [`Any`](std::any::Any) so that it can be
474+
/// downcast to a specific implementation.
475+
fn as_any(&self) -> &dyn Any;
476+
477+
/// the field of the final result of this aggregation.
478+
fn field(&self) -> Result<Field>;
479+
480+
/// expressions that are passed to the Accumulator.
481+
/// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many.
482+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
483+
484+
/// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default
485+
/// implementation returns placeholder text.
486+
fn name(&self) -> &str {
487+
"BuiltInWindowFunctionExpr: default name"
488+
}
460489
}
461490

462491
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
463-
/// generically accumulates values. An accumulator knows how to:
492+
/// generically accumulates values.
493+
///
494+
/// An accumulator knows how to:
464495
/// * update its state from inputs via `update`
465496
/// * convert its internal state to a vector of scalar values
466497
/// * update its state from multiple accumulators' states via `merge`
@@ -509,6 +540,43 @@ pub trait Accumulator: Send + Sync + Debug {
509540
fn evaluate(&self) -> Result<ScalarValue>;
510541
}
511542

543+
/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple
544+
/// rows and generically accumulates values.
545+
///
546+
/// An accumulator knows how to:
547+
/// * update its state from inputs via `scan`
548+
/// * convert its internal state to a vector of scalar values
549+
/// * update its state from multiple accumulators' states via `merge`
550+
/// * compute the final value from its internal state via `evaluate`
551+
pub trait WindowAccumulator: Send + Sync + Debug {
552+
/// scans the accumulator's state from a vector of scalars, similar to Accumulator it also
553+
/// optionally generates values.
554+
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<Vec<ScalarValue>>>;
555+
556+
/// scans the accumulator's state from a vector of arrays.
557+
fn scan_batch(
558+
&mut self,
559+
values: &[ArrayRef],
560+
) -> Result<Vec<Option<Vec<ScalarValue>>>> {
561+
if values.is_empty() {
562+
return Ok(vec![]);
563+
};
564+
(0..values[0].len())
565+
.map(|index: usize| {
566+
let v = values
567+
.iter()
568+
.map(|array| ScalarValue::try_from_array(array, index))
569+
.collect::<Result<Vec<_>>>()?;
570+
self.scan(&v)
571+
})
572+
.into_iter()
573+
.collect::<Result<Vec<_>>>()
574+
}
575+
576+
/// returns its value based on its current state.
577+
fn evaluate(&self) -> Result<Option<ScalarValue>>;
578+
}
579+
512580
pub mod aggregates;
513581
pub mod array_expressions;
514582
pub mod coalesce_batches;

datafusion/src/physical_plan/planner.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner {
147147
// Initially need to perform the aggregate and then merge the partitions
148148
let input_exec = self.create_initial_plan(input, ctx_state)?;
149149
let input_schema = input_exec.schema();
150-
let physical_input_schema = input_exec.as_ref().schema();
150+
151151
let logical_input_schema = input.as_ref().schema();
152+
let physical_input_schema = input_exec.as_ref().schema();
153+
152154
let window_expr = window_expr
153155
.iter()
154156
.map(|e| {

datafusion/src/physical_plan/sort.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ fn sort_batches(
227227
}
228228

229229
pin_project! {
230+
/// stream for sort plan
230231
struct SortStream {
231232
#[pin]
232233
output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,

datafusion/src/physical_plan/window_functions.rs

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,49 +143,64 @@ impl FromStr for BuiltInWindowFunction {
143143

144144
/// Returns the datatype of the window function
145145
pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataType> {
146+
match fun {
147+
WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types),
148+
WindowFunction::BuiltInWindowFunction(fun) => {
149+
return_type_for_built_in(fun, arg_types)
150+
}
151+
}
152+
}
153+
154+
/// Returns the datatype of the built-in window function
155+
pub(super) fn return_type_for_built_in(
156+
fun: &BuiltInWindowFunction,
157+
arg_types: &[DataType],
158+
) -> Result<DataType> {
146159
// Note that this function *must* return the same type that the respective physical expression returns
147160
// or the execution panics.
148161

149162
// verify that this is a valid set of data types for this function
150-
data_types(arg_types, &signature(fun))?;
163+
data_types(arg_types, &signature_for_built_in(fun))?;
151164

152165
match fun {
153-
WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types),
154-
WindowFunction::BuiltInWindowFunction(fun) => match fun {
155-
BuiltInWindowFunction::RowNumber
156-
| BuiltInWindowFunction::Rank
157-
| BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
158-
BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
159-
Ok(DataType::Float64)
160-
}
161-
BuiltInWindowFunction::Ntile => Ok(DataType::UInt32),
162-
BuiltInWindowFunction::Lag
163-
| BuiltInWindowFunction::Lead
164-
| BuiltInWindowFunction::FirstValue
165-
| BuiltInWindowFunction::LastValue
166-
| BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()),
167-
},
166+
BuiltInWindowFunction::RowNumber
167+
| BuiltInWindowFunction::Rank
168+
| BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
169+
BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
170+
Ok(DataType::Float64)
171+
}
172+
BuiltInWindowFunction::Ntile => Ok(DataType::UInt32),
173+
BuiltInWindowFunction::Lag
174+
| BuiltInWindowFunction::Lead
175+
| BuiltInWindowFunction::FirstValue
176+
| BuiltInWindowFunction::LastValue
177+
| BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()),
168178
}
169179
}
170180

171181
/// the signatures supported by the function `fun`.
172-
fn signature(fun: &WindowFunction) -> Signature {
173-
// note: the physical expression must accept the type returned by this function or the execution panics.
182+
pub fn signature(fun: &WindowFunction) -> Signature {
174183
match fun {
175184
WindowFunction::AggregateFunction(fun) => aggregates::signature(fun),
176-
WindowFunction::BuiltInWindowFunction(fun) => match fun {
177-
BuiltInWindowFunction::RowNumber
178-
| BuiltInWindowFunction::Rank
179-
| BuiltInWindowFunction::DenseRank
180-
| BuiltInWindowFunction::PercentRank
181-
| BuiltInWindowFunction::CumeDist => Signature::Any(0),
182-
BuiltInWindowFunction::Lag
183-
| BuiltInWindowFunction::Lead
184-
| BuiltInWindowFunction::FirstValue
185-
| BuiltInWindowFunction::LastValue => Signature::Any(1),
186-
BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
187-
BuiltInWindowFunction::NthValue => Signature::Any(2),
188-
},
185+
WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun),
186+
}
187+
}
188+
189+
/// the signatures supported by the built-in window function `fun`.
190+
pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
191+
// note: the physical expression must accept the type returned by this function or the execution panics.
192+
match fun {
193+
BuiltInWindowFunction::RowNumber
194+
| BuiltInWindowFunction::Rank
195+
| BuiltInWindowFunction::DenseRank
196+
| BuiltInWindowFunction::PercentRank
197+
| BuiltInWindowFunction::CumeDist => Signature::Any(0),
198+
BuiltInWindowFunction::Lag
199+
| BuiltInWindowFunction::Lead
200+
| BuiltInWindowFunction::FirstValue
201+
| BuiltInWindowFunction::LastValue => Signature::Any(1),
202+
BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
203+
BuiltInWindowFunction::NthValue => Signature::Any(2),
189204
}
190205
}
191206

0 commit comments

Comments
 (0)