Skip to content

Commit a300aae

Browse files
author
Jiayu Liu
committed
add window expr
1 parent aa26112 commit a300aae

File tree

21 files changed

+1236
-103
lines changed

21 files changed

+1236
-103
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ message LogicalExprNode {
3939

4040
ScalarValue literal = 3;
4141

42-
4342
// binary expressions
4443
BinaryExprNode binary_expr = 4;
4544

@@ -60,6 +59,9 @@ message LogicalExprNode {
6059
bool wildcard = 15;
6160
ScalarFunctionNode scalar_function = 16;
6261
TryCastNode try_cast = 17;
62+
63+
// window expressions
64+
WindowExprNode window_expr = 18;
6365
}
6466
}
6567

@@ -151,6 +153,25 @@ message AggregateExprNode {
151153
LogicalExprNode expr = 2;
152154
}
153155

156+
// Window functions, other than aggregate functions, that a WindowExprNode
// can select through its `built_in_function` field.
enum BuiltInWindowFunction {
  ROW_NUMBER = 0;
  RANK = 1;
  DENSE_RANK = 2;
  LAG = 3;
  LEAD = 4;
  FIRST_VALUE = 5;
  LAST_VALUE = 6;
}
165+
166+
// A window function call: which function to apply, plus the expression it
// is applied to.
message WindowExprNode {
  // The function being called: either an aggregate function used as a
  // window function, or one of the built-in window functions.
  oneof window_function {
    AggregateFunction aggr_function = 1;
    BuiltInWindowFunction built_in_function = 2;
    // udaf = 3
    // NOTE(review): field number 3 is presumably being held back for
    // user-defined aggregate functions - confirm before reusing it.
  }
  // The expression the window function is evaluated over.
  LogicalExprNode expr = 4;
}
174+
154175
message BetweenNode {
155176
LogicalExprNode expr = 1;
156177
bool negated = 2;
@@ -200,6 +221,7 @@ message LogicalPlanNode {
200221
EmptyRelationNode empty_relation = 10;
201222
CreateExternalTableNode create_external_table = 11;
202223
ExplainNode explain = 12;
224+
WindowNode window = 13;
203225
}
204226
}
205227

@@ -288,6 +310,49 @@ message AggregateNode {
288310
repeated LogicalExprNode aggr_expr = 3;
289311
}
290312

313+
// Logical plan node that evaluates window expressions over its input plan.
message WindowNode {
  // The plan producing the rows the window expressions are evaluated over.
  LogicalPlanNode input = 1;
  // The window expressions to evaluate.
  repeated LogicalExprNode window_expr = 2;
  // Expressions of the PARTITION BY clause.
  repeated LogicalExprNode partition_by_expr = 3;
  // Expressions of the ORDER BY clause.
  repeated LogicalExprNode order_by_expr = 4;
  // Optional window frame, written as a single-member oneof.
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof window_frame {
    WindowFrame frame = 5;
  }
}
324+
325+
// The unit in which a window frame's bounds are expressed.
enum WindowFrameUnits {
  ROWS = 0;
  RANGE = 1;
  GROUPS = 2;
}
330+
331+
// A window frame: its units, a start bound and an optional end bound.
message WindowFrame {
  // Unit used to interpret the bounds (ROWS, RANGE or GROUPS).
  WindowFrameUnits window_frame_units = 1;
  // Where the frame starts.
  WindowFrameBound start_bound = 2;
  // Optional end of the frame, written as a single-member oneof.
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof end_bound {
    WindowFrameBound bound = 3;
  }
}
340+
341+
// The kind of a window frame bound: the current row, or an offset before
// (PRECEDING) or after (FOLLOWING) it.
enum WindowFrameBoundType {
  CURRENT_ROW = 0;
  PRECEDING = 1;
  FOLLOWING = 2;
}
346+
347+
// One edge of a window frame: its kind plus an optional numeric value.
message WindowFrameBound {
  // Whether this bound is the current row, PRECEDING or FOLLOWING.
  WindowFrameBoundType window_frame_bound_type = 1;
  // Optional numeric value of the bound, written as a single-member oneof.
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof bound_value {
    uint64 value = 2;
  }
}
355+
291356
enum JoinType {
292357
INNER = 0;
293358
LEFT = 1;
@@ -334,6 +399,7 @@ message PhysicalPlanNode {
334399
MergeExecNode merge = 14;
335400
UnresolvedShuffleExecNode unresolved = 15;
336401
RepartitionExecNode repartition = 16;
402+
WindowAggExecNode window = 17;
337403
}
338404
}
339405

@@ -399,6 +465,13 @@ enum AggregateMode {
399465
FINAL_PARTITIONED = 2;
400466
}
401467

468+
// Physical plan node that evaluates window expressions over its input plan.
message WindowAggExecNode {
  // The physical plan producing the input rows.
  PhysicalPlanNode input = 1;
  // The window expressions to evaluate.
  repeated LogicalExprNode window_expr = 2;
  // Names for the window expressions.
  // NOTE(review): presumably parallel to `window_expr`, one name per
  // expression - confirm against the serializer.
  repeated string window_expr_name = 3;
  // Schema of the input plan.
  Schema input_schema = 4;
}
474+
402475
message HashAggregateExecNode {
403476
repeated LogicalExprNode group_expr = 1;
404477
repeated LogicalExprNode aggr_expr = 2;

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 179 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717

1818
//! Serde code to convert from protocol buffers to Rust data structures.
1919
20+
use crate::error::BallistaError;
21+
use crate::serde::{proto_error, protobuf};
22+
use crate::{convert_box_required, convert_required};
23+
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
2024
use std::{
2125
convert::{From, TryInto},
2226
unimplemented,
2327
};
2428

25-
use crate::error::BallistaError;
26-
use crate::serde::{proto_error, protobuf};
27-
use crate::{convert_box_required, convert_required};
28-
2929
use arrow::datatypes::{DataType, Field, Schema};
3030
use datafusion::logical_plan::{
3131
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
3232
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
3333
};
3434
use datafusion::physical_plan::aggregates::AggregateFunction;
3535
use datafusion::physical_plan::csv::CsvReadOptions;
36+
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
3637
use datafusion::scalar::ScalarValue;
3738
use protobuf::logical_plan_node::LogicalPlanType;
3839
use protobuf::{logical_expr_node::ExprType, scalar_type};
@@ -75,6 +76,33 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
7576
.build()
7677
.map_err(|e| e.into())
7778
}
79+
LogicalPlanType::Window(window) => {
80+
let input: LogicalPlan = convert_box_required!(window.input)?;
81+
let window_expr = window
82+
.window_expr
83+
.iter()
84+
.map(|expr| expr.try_into())
85+
.collect::<Result<Vec<_>, _>>()?;
86+
87+
// let partition_by_expr = window
88+
// .partition_by_expr
89+
// .iter()
90+
// .map(|expr| expr.try_into())
91+
// .collect::<Result<Vec<_>, _>>()?;
92+
// let order_by_expr = window
93+
// .order_by_expr
94+
// .iter()
95+
// .map(|expr| expr.try_into())
96+
// .collect::<Result<Vec<_>, _>>()?;
97+
// // FIXME parse the window_frame data
98+
// let window_frame = None;
99+
LogicalPlanBuilder::from(&input)
100+
.window(
101+
window_expr, /*, partition_by_expr, order_by_expr, window_frame*/
102+
)?
103+
.build()
104+
.map_err(|e| e.into())
105+
}
78106
LogicalPlanType::Aggregate(aggregate) => {
79107
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
80108
let group_expr = aggregate
@@ -871,7 +899,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
871899
type Error = BallistaError;
872900

873901
fn try_into(self) -> Result<Expr, Self::Error> {
902+
use datafusion::physical_plan::window_functions;
874903
use protobuf::logical_expr_node::ExprType;
904+
use protobuf::window_expr_node;
905+
use protobuf::WindowExprNode;
875906

876907
let expr_type = self
877908
.expr_type
@@ -889,6 +920,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
889920
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
890921
Ok(Expr::Literal(scalar_value))
891922
}
923+
ExprType::WindowExpr(expr) => {
924+
let window_function = expr
925+
.window_function
926+
.as_ref()
927+
.ok_or_else(|| proto_error("Received empty window function"))?;
928+
match window_function {
929+
window_expr_node::WindowFunction::AggrFunction(i) => {
930+
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
931+
.ok_or_else(|| {
932+
proto_error(format!(
933+
"Received an unknown aggregate window function: {}",
934+
i
935+
))
936+
})?;
937+
938+
Ok(Expr::WindowFunction {
939+
fun: window_functions::WindowFunction::AggregateFunction(
940+
AggregateFunction::from(aggr_function),
941+
),
942+
args: vec![parse_required_expr(&expr.expr)?],
943+
})
944+
}
945+
window_expr_node::WindowFunction::BuiltInFunction(i) => {
946+
let built_in_function =
947+
protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
948+
|| {
949+
proto_error(format!(
950+
"Received an unknown built-in window function: {}",
951+
i
952+
))
953+
},
954+
)?;
955+
956+
Ok(Expr::WindowFunction {
957+
fun: window_functions::WindowFunction::BuiltInWindowFunction(
958+
BuiltInWindowFunction::from(built_in_function),
959+
),
960+
args: vec![parse_required_expr(&expr.expr)?],
961+
})
962+
}
963+
}
964+
}
892965
ExprType::AggregateExpr(expr) => {
893966
let aggr_function =
894967
protobuf::AggregateFunction::from_i32(expr.aggr_function)
@@ -898,13 +971,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
898971
expr.aggr_function
899972
))
900973
})?;
901-
let fun = match aggr_function {
902-
protobuf::AggregateFunction::Min => AggregateFunction::Min,
903-
protobuf::AggregateFunction::Max => AggregateFunction::Max,
904-
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
905-
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
906-
protobuf::AggregateFunction::Count => AggregateFunction::Count,
907-
};
974+
let fun = AggregateFunction::from(aggr_function);
908975

909976
Ok(Expr::AggregateFunction {
910977
fun,
@@ -1152,6 +1219,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
11521219
}
11531220

11541221
use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
1222+
use datafusion::physical_plan::{aggregates, windows};
11551223
use datafusion::prelude::{
11561224
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
11571225
};
@@ -1202,3 +1270,103 @@ fn parse_optional_expr(
12021270
None => Ok(None),
12031271
}
12041272
}
1273+
1274+
impl From<protobuf::WindowFrameUnits> for WindowFrameUnits {
1275+
fn from(units: protobuf::WindowFrameUnits) -> Self {
1276+
match units {
1277+
protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
1278+
protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range,
1279+
protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
1280+
}
1281+
}
1282+
}
1283+
1284+
impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
1285+
type Error = BallistaError;
1286+
1287+
fn try_from(bound: protobuf::WindowFrameBound) -> Result<Self, Self::Error> {
1288+
let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| {
1289+
proto_error(format!(
1290+
"Received a WindowFrameBound message with unknown WindowFrameBoundType {}",
1291+
bound.window_frame_bound_type
1292+
))
1293+
})?.into();
1294+
match bound_type {
1295+
protobuf::WindowFrameBoundType::CurrentRow => {
1296+
Ok(WindowFrameBound::CurrentRow)
1297+
}
1298+
protobuf::WindowFrameBoundType::Preceding => {
1299+
// FIXME implement bound value parsing
1300+
Ok(WindowFrameBound::Preceding(Some(1)))
1301+
}
1302+
protobuf::WindowFrameBoundType::Following => {
1303+
// FIXME implement bound value parsing
1304+
Ok(WindowFrameBound::Following(Some(1)))
1305+
}
1306+
}
1307+
}
1308+
}
1309+
1310+
impl TryFrom<protobuf::WindowFrame> for WindowFrame {
1311+
type Error = BallistaError;
1312+
1313+
fn try_from(window: protobuf::WindowFrame) -> Result<Self, Self::Error> {
1314+
let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units)
1315+
.ok_or_else(|| {
1316+
proto_error(format!(
1317+
"Received a WindowFrame message with unknown WindowFrameUnits {}",
1318+
window.window_frame_units
1319+
))
1320+
})?
1321+
.into();
1322+
let start_bound = window
1323+
.start_bound
1324+
.ok_or_else(|| {
1325+
proto_error(
1326+
"Received a WindowFrame message with no start_bound".to_owned(),
1327+
)
1328+
})?
1329+
.try_into()?;
1330+
// FIXME parse end bound
1331+
let end_bound = None;
1332+
Ok(WindowFrame {
1333+
units,
1334+
start_bound,
1335+
end_bound,
1336+
})
1337+
}
1338+
}
1339+
1340+
impl From<protobuf::AggregateFunction> for AggregateFunction {
1341+
fn from(aggr_function: protobuf::AggregateFunction) -> Self {
1342+
match aggr_function {
1343+
protobuf::AggregateFunction::Min => AggregateFunction::Min,
1344+
protobuf::AggregateFunction::Max => AggregateFunction::Max,
1345+
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
1346+
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
1347+
protobuf::AggregateFunction::Count => AggregateFunction::Count,
1348+
}
1349+
}
1350+
}
1351+
1352+
impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
1353+
fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
1354+
match built_in_function {
1355+
protobuf::BuiltInWindowFunction::RowNumber => {
1356+
BuiltInWindowFunction::RowNumber
1357+
}
1358+
protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
1359+
protobuf::BuiltInWindowFunction::DenseRank => {
1360+
BuiltInWindowFunction::DenseRank
1361+
}
1362+
protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
1363+
protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
1364+
protobuf::BuiltInWindowFunction::FirstValue => {
1365+
BuiltInWindowFunction::FirstValue
1366+
}
1367+
protobuf::BuiltInWindowFunction::LastValue => {
1368+
BuiltInWindowFunction::LastValue
1369+
}
1370+
}
1371+
}
1372+
}

0 commit comments

Comments
 (0)