Skip to content

Commit db4f098

Browse files
jimexistalamb
andauthored
Add window expression part 1 - logical and physical planning, structure, to/from proto, and explain, for empty over clause only (#334)
* add window expr * fix unused imports * fix clippy * fix unit test * Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * adding more built-in functions * adding filter by todo * enrich unit test * update * add more tests * fix test * fix unit test * fix error * fix unit test * fix unit test * use upper case * fix unit test * comment out test Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 913bf86 commit db4f098

File tree

21 files changed

+1498
-102
lines changed

21 files changed

+1498
-102
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ message LogicalExprNode {
3939

4040
ScalarValue literal = 3;
4141

42-
4342
// binary expressions
4443
BinaryExprNode binary_expr = 4;
4544

@@ -60,6 +59,9 @@ message LogicalExprNode {
6059
bool wildcard = 15;
6160
ScalarFunctionNode scalar_function = 16;
6261
TryCastNode try_cast = 17;
62+
63+
// window expressions
64+
WindowExprNode window_expr = 18;
6365
}
6466
}
6567

@@ -151,6 +153,29 @@ message AggregateExprNode {
151153
LogicalExprNode expr = 2;
152154
}
153155

156+
enum BuiltInWindowFunction {
157+
ROW_NUMBER = 0;
158+
RANK = 1;
159+
DENSE_RANK = 2;
160+
PERCENT_RANK = 3;
161+
CUME_DIST = 4;
162+
NTILE = 5;
163+
LAG = 6;
164+
LEAD = 7;
165+
FIRST_VALUE = 8;
166+
LAST_VALUE = 9;
167+
NTH_VALUE = 10;
168+
}
169+
170+
message WindowExprNode {
171+
oneof window_function {
172+
AggregateFunction aggr_function = 1;
173+
BuiltInWindowFunction built_in_function = 2;
174+
// udaf = 3
175+
}
176+
LogicalExprNode expr = 4;
177+
}
178+
154179
message BetweenNode {
155180
LogicalExprNode expr = 1;
156181
bool negated = 2;
@@ -200,6 +225,7 @@ message LogicalPlanNode {
200225
EmptyRelationNode empty_relation = 10;
201226
CreateExternalTableNode create_external_table = 11;
202227
ExplainNode explain = 12;
228+
WindowNode window = 13;
203229
}
204230
}
205231

@@ -288,6 +314,50 @@ message AggregateNode {
288314
repeated LogicalExprNode aggr_expr = 3;
289315
}
290316

317+
message WindowNode {
318+
LogicalPlanNode input = 1;
319+
repeated LogicalExprNode window_expr = 2;
320+
repeated LogicalExprNode partition_by_expr = 3;
321+
repeated LogicalExprNode order_by_expr = 4;
322+
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
323+
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
324+
oneof window_frame {
325+
WindowFrame frame = 5;
326+
}
327+
// TODO add filter by expr
328+
}
329+
330+
enum WindowFrameUnits {
331+
ROWS = 0;
332+
RANGE = 1;
333+
GROUPS = 2;
334+
}
335+
336+
message WindowFrame {
337+
WindowFrameUnits window_frame_units = 1;
338+
WindowFrameBound start_bound = 2;
339+
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
340+
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
341+
oneof end_bound {
342+
WindowFrameBound bound = 3;
343+
}
344+
}
345+
346+
enum WindowFrameBoundType {
347+
CURRENT_ROW = 0;
348+
PRECEDING = 1;
349+
FOLLOWING = 2;
350+
}
351+
352+
message WindowFrameBound {
353+
WindowFrameBoundType window_frame_bound_type = 1;
354+
// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
355+
// this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
356+
oneof bound_value {
357+
uint64 value = 2;
358+
}
359+
}
360+
291361
enum JoinType {
292362
INNER = 0;
293363
LEFT = 1;
@@ -334,6 +404,7 @@ message PhysicalPlanNode {
334404
MergeExecNode merge = 14;
335405
UnresolvedShuffleExecNode unresolved = 15;
336406
RepartitionExecNode repartition = 16;
407+
WindowAggExecNode window = 17;
337408
}
338409
}
339410

@@ -399,6 +470,13 @@ enum AggregateMode {
399470
FINAL_PARTITIONED = 2;
400471
}
401472

473+
message WindowAggExecNode {
474+
PhysicalPlanNode input = 1;
475+
repeated LogicalExprNode window_expr = 2;
476+
repeated string window_expr_name = 3;
477+
Schema input_schema = 4;
478+
}
479+
402480
message HashAggregateExecNode {
403481
repeated LogicalExprNode group_expr = 1;
404482
repeated LogicalExprNode aggr_expr = 2;

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 186 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717

1818
//! Serde code to convert from protocol buffers to Rust data structures.
1919
20+
use crate::error::BallistaError;
21+
use crate::serde::{proto_error, protobuf};
22+
use crate::{convert_box_required, convert_required};
23+
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
2024
use std::{
2125
convert::{From, TryInto},
2226
unimplemented,
2327
};
2428

25-
use crate::error::BallistaError;
26-
use crate::serde::{proto_error, protobuf};
27-
use crate::{convert_box_required, convert_required};
28-
2929
use arrow::datatypes::{DataType, Field, Schema};
3030
use datafusion::logical_plan::{
3131
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
3232
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
3333
};
3434
use datafusion::physical_plan::aggregates::AggregateFunction;
3535
use datafusion::physical_plan::csv::CsvReadOptions;
36+
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
3637
use datafusion::scalar::ScalarValue;
3738
use protobuf::logical_plan_node::LogicalPlanType;
3839
use protobuf::{logical_expr_node::ExprType, scalar_type};
@@ -75,6 +76,34 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
7576
.build()
7677
.map_err(|e| e.into())
7778
}
79+
LogicalPlanType::Window(window) => {
80+
let input: LogicalPlan = convert_box_required!(window.input)?;
81+
let window_expr = window
82+
.window_expr
83+
.iter()
84+
.map(|expr| expr.try_into())
85+
.collect::<Result<Vec<_>, _>>()?;
86+
87+
// let partition_by_expr = window
88+
// .partition_by_expr
89+
// .iter()
90+
// .map(|expr| expr.try_into())
91+
// .collect::<Result<Vec<_>, _>>()?;
92+
// let order_by_expr = window
93+
// .order_by_expr
94+
// .iter()
95+
// .map(|expr| expr.try_into())
96+
// .collect::<Result<Vec<_>, _>>()?;
97+
// // FIXME: add filter by expr
98+
// // FIXME: parse the window_frame data
99+
// let window_frame = None;
100+
LogicalPlanBuilder::from(&input)
101+
.window(
102+
window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/
103+
)?
104+
.build()
105+
.map_err(|e| e.into())
106+
}
78107
LogicalPlanType::Aggregate(aggregate) => {
79108
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
80109
let group_expr = aggregate
@@ -871,7 +900,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
871900
type Error = BallistaError;
872901

873902
fn try_into(self) -> Result<Expr, Self::Error> {
903+
use datafusion::physical_plan::window_functions;
874904
use protobuf::logical_expr_node::ExprType;
905+
use protobuf::window_expr_node;
906+
use protobuf::WindowExprNode;
875907

876908
let expr_type = self
877909
.expr_type
@@ -889,6 +921,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
889921
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
890922
Ok(Expr::Literal(scalar_value))
891923
}
924+
ExprType::WindowExpr(expr) => {
925+
let window_function = expr
926+
.window_function
927+
.as_ref()
928+
.ok_or_else(|| proto_error("Received empty window function"))?;
929+
match window_function {
930+
window_expr_node::WindowFunction::AggrFunction(i) => {
931+
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
932+
.ok_or_else(|| {
933+
proto_error(format!(
934+
"Received an unknown aggregate window function: {}",
935+
i
936+
))
937+
})?;
938+
939+
Ok(Expr::WindowFunction {
940+
fun: window_functions::WindowFunction::AggregateFunction(
941+
AggregateFunction::from(aggr_function),
942+
),
943+
args: vec![parse_required_expr(&expr.expr)?],
944+
})
945+
}
946+
window_expr_node::WindowFunction::BuiltInFunction(i) => {
947+
let built_in_function =
948+
protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
949+
|| {
950+
proto_error(format!(
951+
"Received an unknown built-in window function: {}",
952+
i
953+
))
954+
},
955+
)?;
956+
957+
Ok(Expr::WindowFunction {
958+
fun: window_functions::WindowFunction::BuiltInWindowFunction(
959+
BuiltInWindowFunction::from(built_in_function),
960+
),
961+
args: vec![parse_required_expr(&expr.expr)?],
962+
})
963+
}
964+
}
965+
}
892966
ExprType::AggregateExpr(expr) => {
893967
let aggr_function =
894968
protobuf::AggregateFunction::from_i32(expr.aggr_function)
@@ -898,13 +972,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
898972
expr.aggr_function
899973
))
900974
})?;
901-
let fun = match aggr_function {
902-
protobuf::AggregateFunction::Min => AggregateFunction::Min,
903-
protobuf::AggregateFunction::Max => AggregateFunction::Max,
904-
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
905-
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
906-
protobuf::AggregateFunction::Count => AggregateFunction::Count,
907-
};
975+
let fun = AggregateFunction::from(aggr_function);
908976

909977
Ok(Expr::AggregateFunction {
910978
fun,
@@ -1152,6 +1220,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
11521220
}
11531221

11541222
use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
1223+
use datafusion::physical_plan::{aggregates, windows};
11551224
use datafusion::prelude::{
11561225
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
11571226
};
@@ -1202,3 +1271,109 @@ fn parse_optional_expr(
12021271
None => Ok(None),
12031272
}
12041273
}
1274+
1275+
impl From<protobuf::WindowFrameUnits> for WindowFrameUnits {
1276+
fn from(units: protobuf::WindowFrameUnits) -> Self {
1277+
match units {
1278+
protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1279+
protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range,
1280+
protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
1281+
}
1282+
}
1283+
}
1284+
1285+
impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
1286+
type Error = BallistaError;
1287+
1288+
fn try_from(bound: protobuf::WindowFrameBound) -> Result<Self, Self::Error> {
1289+
let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| {
1290+
proto_error(format!(
1291+
"Received a WindowFrameBound message with unknown WindowFrameBoundType {}",
1292+
bound.window_frame_bound_type
1293+
))
1294+
})?;
1295+
match bound_type {
1296+
protobuf::WindowFrameBoundType::CurrentRow => {
1297+
Ok(WindowFrameBound::CurrentRow)
1298+
}
1299+
protobuf::WindowFrameBoundType::Preceding => {
1300+
// FIXME implement bound value parsing
1301+
Ok(WindowFrameBound::Preceding(Some(1)))
1302+
}
1303+
protobuf::WindowFrameBoundType::Following => {
1304+
// FIXME implement bound value parsing
1305+
Ok(WindowFrameBound::Following(Some(1)))
1306+
}
1307+
}
1308+
}
1309+
}
1310+
1311+
impl TryFrom<protobuf::WindowFrame> for WindowFrame {
1312+
type Error = BallistaError;
1313+
1314+
fn try_from(window: protobuf::WindowFrame) -> Result<Self, Self::Error> {
1315+
let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units)
1316+
.ok_or_else(|| {
1317+
proto_error(format!(
1318+
"Received a WindowFrame message with unknown WindowFrameUnits {}",
1319+
window.window_frame_units
1320+
))
1321+
})?
1322+
.into();
1323+
let start_bound = window
1324+
.start_bound
1325+
.ok_or_else(|| {
1326+
proto_error(
1327+
"Received a WindowFrame message with no start_bound".to_owned(),
1328+
)
1329+
})?
1330+
.try_into()?;
1331+
// FIXME parse end bound
1332+
let end_bound = None;
1333+
Ok(WindowFrame {
1334+
units,
1335+
start_bound,
1336+
end_bound,
1337+
})
1338+
}
1339+
}
1340+
1341+
impl From<protobuf::AggregateFunction> for AggregateFunction {
1342+
fn from(aggr_function: protobuf::AggregateFunction) -> Self {
1343+
match aggr_function {
1344+
protobuf::AggregateFunction::Min => AggregateFunction::Min,
1345+
protobuf::AggregateFunction::Max => AggregateFunction::Max,
1346+
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
1347+
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
1348+
protobuf::AggregateFunction::Count => AggregateFunction::Count,
1349+
}
1350+
}
1351+
}
1352+
1353+
impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
1354+
fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
1355+
match built_in_function {
1356+
protobuf::BuiltInWindowFunction::RowNumber => {
1357+
BuiltInWindowFunction::RowNumber
1358+
}
1359+
protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
1360+
protobuf::BuiltInWindowFunction::PercentRank => {
1361+
BuiltInWindowFunction::PercentRank
1362+
}
1363+
protobuf::BuiltInWindowFunction::DenseRank => {
1364+
BuiltInWindowFunction::DenseRank
1365+
}
1366+
protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
1367+
protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
1368+
protobuf::BuiltInWindowFunction::FirstValue => {
1369+
BuiltInWindowFunction::FirstValue
1370+
}
1371+
protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
1372+
protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
1373+
protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
1374+
protobuf::BuiltInWindowFunction::LastValue => {
1375+
BuiltInWindowFunction::LastValue
1376+
}
1377+
}
1378+
}
1379+
}

0 commit comments

Comments
 (0)