Skip to content

Commit 11e9541

Browse files
author
Jiayu Liu
committed
add window expr
1 parent 1702d6c commit 11e9541

File tree

16 files changed

+493
-13
lines changed

16 files changed

+493
-13
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ message LogicalExprNode {
3939

4040
ScalarValue literal = 3;
4141

42-
4342
// binary expressions
4443
BinaryExprNode binary_expr = 4;
4544

@@ -60,6 +59,9 @@ message LogicalExprNode {
6059
bool wildcard = 15;
6160
ScalarFunctionNode scalar_function = 16;
6261
TryCastNode try_cast = 17;
62+
63+
// window expressions
64+
WindowExprNode window_expr = 18;
6365
}
6466
}
6567

@@ -151,6 +153,25 @@ message AggregateExprNode {
151153
LogicalExprNode expr = 2;
152154
}
153155

156+
enum BuiltInWindowFunction {
157+
ROW_NUMBER = 0;
158+
RANK = 1;
159+
DENSE_RANK = 2;
160+
LAG = 3;
161+
LEAD = 4;
162+
FIRST_VALUE = 5;
163+
LAST_VALUE = 6;
164+
}
165+
166+
message WindowExprNode {
167+
oneof window_function {
168+
AggregateFunction aggr_function = 1;
169+
BuiltInWindowFunction built_in_function = 2;
170+
// udaf = 3
171+
}
172+
LogicalExprNode expr = 4;
173+
}
174+
154175
message BetweenNode {
155176
LogicalExprNode expr = 1;
156177
bool negated = 2;
@@ -200,6 +221,7 @@ message LogicalPlanNode {
200221
EmptyRelationNode empty_relation = 10;
201222
CreateExternalTableNode create_external_table = 11;
202223
ExplainNode explain = 12;
224+
WindowNode window = 13;
203225
}
204226
}
205227

@@ -288,6 +310,12 @@ message AggregateNode {
288310
repeated LogicalExprNode aggr_expr = 3;
289311
}
290312

313+
message WindowNode {
314+
LogicalPlanNode input = 1;
315+
repeated LogicalExprNode partition_by_expr = 2;
316+
repeated LogicalExprNode order_by_expr = 3;
317+
}
318+
291319
enum JoinType {
292320
INNER = 0;
293321
LEFT = 1;

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion::logical_plan::{
3333
};
3434
use datafusion::physical_plan::aggregates::AggregateFunction;
3535
use datafusion::physical_plan::csv::CsvReadOptions;
36+
use datafusion::physical_plan::windows::BuiltInWindowFunction;
3637
use datafusion::scalar::ScalarValue;
3738
use protobuf::logical_plan_node::LogicalPlanType;
3839
use protobuf::{logical_expr_node::ExprType, scalar_type};
@@ -75,6 +76,23 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
7576
.build()
7677
.map_err(|e| e.into())
7778
}
79+
LogicalPlanType::Window(window) => {
80+
let input: LogicalPlan = convert_box_required!(window.input)?;
81+
let partition_by_expr = window
82+
.partition_by_expr
83+
.iter()
84+
.map(|expr| expr.try_into())
85+
.collect::<Result<Vec<_>, _>>()?;
86+
let order_by_expr = window
87+
.order_by_expr
88+
.iter()
89+
.map(|expr| expr.try_into())
90+
.collect::<Result<Vec<_>, _>>()?;
91+
LogicalPlanBuilder::from(&input)
92+
.window(partition_by_expr, order_by_expr)?
93+
.build()
94+
.map_err(|e| e.into())
95+
}
7896
LogicalPlanType::Aggregate(aggregate) => {
7997
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
8098
let group_expr = aggregate
@@ -872,6 +890,8 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
872890

873891
fn try_into(self) -> Result<Expr, Self::Error> {
874892
use protobuf::logical_expr_node::ExprType;
893+
use protobuf::window_expr_node;
894+
use protobuf::WindowExprNode;
875895

876896
let expr_type = self
877897
.expr_type
@@ -889,6 +909,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
889909
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
890910
Ok(Expr::Literal(scalar_value))
891911
}
912+
ExprType::WindowExpr(expr) => {
913+
let window_function = expr
914+
.window_function
915+
.as_ref()
916+
.ok_or_else(|| proto_error("Received empty window function"))?;
917+
match window_function {
918+
window_expr_node::WindowFunction::AggrFunction(i) => {
919+
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
920+
.ok_or_else(|| {
921+
proto_error(format!(
922+
"Received an unknown aggregate window function: {}",
923+
i
924+
))
925+
})?;
926+
927+
Ok(Expr::WindowFunction {
928+
fun: windows::WindowFunction::AggregateFunction(
929+
AggregateFunction::from(aggr_function),
930+
),
931+
args: vec![parse_required_expr(&expr.expr)?],
932+
})
933+
}
934+
window_expr_node::WindowFunction::BuiltInFunction(i) => {
935+
let built_in_function =
936+
protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
937+
|| {
938+
proto_error(format!(
939+
"Received an unknown aggregate window function: {}",
940+
i
941+
))
942+
},
943+
)?;
944+
945+
Ok(Expr::WindowFunction {
946+
fun: windows::WindowFunction::BuiltInWindowFunction(
947+
BuiltInWindowFunction::from(built_in_function),
948+
),
949+
args: vec![parse_required_expr(&expr.expr)?],
950+
})
951+
}
952+
}
953+
}
892954
ExprType::AggregateExpr(expr) => {
893955
let aggr_function =
894956
protobuf::AggregateFunction::from_i32(expr.aggr_function)
@@ -898,13 +960,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
898960
expr.aggr_function
899961
))
900962
})?;
901-
let fun = match aggr_function {
902-
protobuf::AggregateFunction::Min => AggregateFunction::Min,
903-
protobuf::AggregateFunction::Max => AggregateFunction::Max,
904-
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
905-
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
906-
protobuf::AggregateFunction::Count => AggregateFunction::Count,
907-
};
963+
let fun = AggregateFunction::from(aggr_function);
908964

909965
Ok(Expr::AggregateFunction {
910966
fun,
@@ -1152,6 +1208,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
11521208
}
11531209

11541210
use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
1211+
use datafusion::physical_plan::{aggregates, windows};
11551212
use datafusion::prelude::{
11561213
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
11571214
};
@@ -1202,3 +1259,37 @@ fn parse_optional_expr(
12021259
None => Ok(None),
12031260
}
12041261
}
1262+
1263+
impl From<protobuf::AggregateFunction> for AggregateFunction {
1264+
fn from(aggr_function: protobuf::AggregateFunction) -> Self {
1265+
match aggr_function {
1266+
protobuf::AggregateFunction::Min => AggregateFunction::Min,
1267+
protobuf::AggregateFunction::Max => AggregateFunction::Max,
1268+
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
1269+
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
1270+
protobuf::AggregateFunction::Count => AggregateFunction::Count,
1271+
}
1272+
}
1273+
}
1274+
1275+
impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
1276+
fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
1277+
match built_in_function {
1278+
protobuf::BuiltInWindowFunction::RowNumber => {
1279+
BuiltInWindowFunction::RowNumber
1280+
}
1281+
protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
1282+
protobuf::BuiltInWindowFunction::DenseRank => {
1283+
BuiltInWindowFunction::DenseRank
1284+
}
1285+
protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
1286+
protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
1287+
protobuf::BuiltInWindowFunction::FirstValue => {
1288+
BuiltInWindowFunction::FirstValue
1289+
}
1290+
protobuf::BuiltInWindowFunction::LastValue => {
1291+
BuiltInWindowFunction::LastValue
1292+
}
1293+
}
1294+
}
1295+
}

ballista/rust/core/src/serde/logical_plan/to_proto.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use arrow::datatypes::{DataType, Schema};
3131
use datafusion::datasource::CsvFile;
3232
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
3333
use datafusion::physical_plan::aggregates::AggregateFunction;
34+
use datafusion::physical_plan::windows::{BuiltInWindowFunction, WindowFunction};
3435
use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields};
3536
use protobuf::{
3637
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, Field,
@@ -772,6 +773,29 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
772773
))),
773774
})
774775
}
776+
LogicalPlan::Window {
777+
input,
778+
partition_by_expr,
779+
order_by_expr,
780+
..
781+
} => {
782+
let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
783+
Ok(protobuf::LogicalPlanNode {
784+
logical_plan_type: Some(LogicalPlanType::Window(Box::new(
785+
protobuf::WindowNode {
786+
input: Some(Box::new(input)),
787+
partition_by_expr: partition_by_expr
788+
.iter()
789+
.map(|expr| expr.try_into())
790+
.collect::<Result<Vec<_>, BallistaError>>()?,
791+
order_by_expr: order_by_expr
792+
.iter()
793+
.map(|expr| expr.try_into())
794+
.collect::<Result<Vec<_>, BallistaError>>()?,
795+
},
796+
))),
797+
})
798+
}
775799
LogicalPlan::Aggregate {
776800
input,
777801
group_expr,
@@ -997,6 +1021,30 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
9971021
expr_type: Some(ExprType::BinaryExpr(binary_expr)),
9981022
})
9991023
}
1024+
Expr::WindowFunction {
1025+
ref fun, ref args, ..
1026+
} => {
1027+
let window_function = match fun {
1028+
WindowFunction::AggregateFunction(fun) => {
1029+
protobuf::window_expr_node::WindowFunction::AggrFunction(
1030+
protobuf::AggregateFunction::from(fun).into(),
1031+
)
1032+
}
1033+
WindowFunction::BuiltInWindowFunction(fun) => {
1034+
protobuf::window_expr_node::WindowFunction::BuiltInFunction(
1035+
protobuf::BuiltInWindowFunction::from(fun).into(),
1036+
)
1037+
}
1038+
};
1039+
let arg = &args[0];
1040+
let window_expr = Box::new(protobuf::WindowExprNode {
1041+
expr: Some(Box::new(arg.try_into()?)),
1042+
window_function: Some(window_function),
1043+
});
1044+
Ok(protobuf::LogicalExprNode {
1045+
expr_type: Some(ExprType::WindowExpr(window_expr)),
1046+
})
1047+
}
10001048
Expr::AggregateFunction {
10011049
ref fun, ref args, ..
10021050
} => {
@@ -1178,6 +1226,32 @@ impl Into<protobuf::Schema> for &Schema {
11781226
}
11791227
}
11801228

1229+
impl From<&AggregateFunction> for protobuf::AggregateFunction {
1230+
fn from(value: &AggregateFunction) -> Self {
1231+
match value {
1232+
AggregateFunction::Min => Self::Min,
1233+
AggregateFunction::Max => Self::Max,
1234+
AggregateFunction::Sum => Self::Sum,
1235+
AggregateFunction::Avg => Self::Avg,
1236+
AggregateFunction::Count => Self::Count,
1237+
}
1238+
}
1239+
}
1240+
1241+
impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction {
1242+
fn from(value: &BuiltInWindowFunction) -> Self {
1243+
match value {
1244+
BuiltInWindowFunction::FirstValue => Self::FirstValue,
1245+
BuiltInWindowFunction::LastValue => Self::LastValue,
1246+
BuiltInWindowFunction::RowNumber => Self::RowNumber,
1247+
BuiltInWindowFunction::Rank => Self::Rank,
1248+
BuiltInWindowFunction::Lag => Self::Lag,
1249+
BuiltInWindowFunction::Lead => Self::Lead,
1250+
BuiltInWindowFunction::DenseRank => Self::DenseRank,
1251+
}
1252+
}
1253+
}
1254+
11811255
impl TryFrom<&arrow::datatypes::DataType> for protobuf::ScalarType {
11821256
type Error = BallistaError;
11831257
fn try_from(value: &arrow::datatypes::DataType) -> Result<Self, Self::Error> {

datafusion/src/logical_plan/builder.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,31 @@ impl LogicalPlanBuilder {
289289
}))
290290
}
291291

292+
/// Apply a window: partition by the `partition_by_expr` expressions
293+
/// and calculating the window spec.
294+
pub fn window(
295+
&self,
296+
partition_by_expr: impl IntoIterator<Item = Expr>,
297+
order_by_expr: impl IntoIterator<Item = Expr>,
298+
) -> Result<Self> {
299+
let partition_by_expr = partition_by_expr.into_iter().collect::<Vec<Expr>>();
300+
let order_by_expr = order_by_expr.into_iter().collect::<Vec<Expr>>();
301+
302+
let all_expr = partition_by_expr.iter().chain(order_by_expr.iter());
303+
304+
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
305+
306+
let window_schema =
307+
DFSchema::new(exprlist_to_fields(all_expr, self.plan.schema())?)?;
308+
309+
Ok(Self::from(&LogicalPlan::Window {
310+
input: Arc::new(self.plan.clone()),
311+
partition_by_expr,
312+
order_by_expr,
313+
schema: DFSchemaRef::new(window_schema),
314+
}))
315+
}
316+
292317
/// Apply an aggregate: grouping on the `group_expr` expressions
293318
/// and calculating `aggr_expr` aggregates for each distinct
294319
/// value of the `group_expr`;

0 commit comments

Comments
 (0)