80 changes: 79 additions & 1 deletion ballista/rust/core/proto/ballista.proto
@@ -39,7 +39,6 @@ message LogicalExprNode {

ScalarValue literal = 3;


// binary expressions
BinaryExprNode binary_expr = 4;

@@ -60,6 +59,9 @@ message LogicalExprNode {
bool wildcard = 15;
ScalarFunctionNode scalar_function = 16;
TryCastNode try_cast = 17;

// window expressions
WindowExprNode window_expr = 18;
}
}

@@ -151,6 +153,29 @@ message AggregateExprNode {
LogicalExprNode expr = 2;
}

enum BuiltInWindowFunction {
  ROW_NUMBER = 0;
  RANK = 1;
  DENSE_RANK = 2;
  PERCENT_RANK = 3;
  CUME_DIST = 4;
  NTILE = 5;
  LAG = 6;
  LEAD = 7;
  FIRST_VALUE = 8;
  LAST_VALUE = 9;
  NTH_VALUE = 10;
}

message WindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
Contributor

I checked whether it makes sense to reuse aggregate functions for window expressions, and I think it does. For example, the PostgreSQL documentation says:
https://www.postgresql.org/docs/9.3/functions-window.html

> In addition to these functions, any built-in or user-defined aggregate function can be used as a window function (see Section 9.20 for a list of the built-in aggregates). Aggregate functions act as window functions only when an OVER clause follows the call; otherwise they act as regular aggregates.

Member Author

Yes, in general three kinds of functions can be used:

  1. aggregation
  2. UDAF
  3. built-in window function

1 and 2 are not order sensitive, but for 3 we'll have to take the sort order into account.

Member Author

[postgres] # explain select c1, count(c3) over (partition by c1 order by c3) from test;
                            QUERY PLAN
------------------------------------------------------------------
 WindowAgg  (cost=6.32..8.32 rows=100 width=12)
   ->  Sort  (cost=6.32..6.57 rows=100 width=4)
         Sort Key: c1, c3
         ->  Seq Scan on test  (cost=0.00..3.00 rows=100 width=4)
(4 rows)
[postgres] # explain select c1, first_value(c3) over (partition by c1 order by c3) from test;
                            QUERY PLAN
------------------------------------------------------------------
 WindowAgg  (cost=6.32..8.32 rows=100 width=6)
   ->  Sort  (cost=6.32..6.57 rows=100 width=4)
         Sort Key: c1, c3
         ->  Seq Scan on test  (cost=0.00..3.00 rows=100 width=4)
(4 rows)

IMO only the second query (first_value) really needs to sort by c3.

Member Author

Also a fun thing to notice:

[postgres] # explain analyze select c1, sum(c3) over (partition by c1 order by c3), avg(c3) over (partition by c1 order by c3 desc) from test;
                                                        QUERY PLAN
--------------------------------------------------------------------------------------------------------------------------
 WindowAgg  (cost=11.64..13.64 rows=100 width=44) (actual time=1.287..1.373 rows=100 loops=1)
   ->  Sort  (cost=11.64..11.89 rows=100 width=36) (actual time=1.281..1.292 rows=100 loops=1)
         Sort Key: c1, c3
         Sort Method: quicksort  Memory: 31kB
         ->  WindowAgg  (cost=6.32..8.32 rows=100 width=36) (actual time=1.051..1.174 rows=100 loops=1)
               ->  Sort  (cost=6.32..6.57 rows=100 width=4) (actual time=0.221..0.231 rows=100 loops=1)
                     Sort Key: c1, c3 DESC
                     Sort Method: quicksort  Memory: 29kB
                     ->  Seq Scan on test  (cost=0.00..3.00 rows=100 width=4) (actual time=0.010..0.028 rows=100 loops=1)
 Planning Time: 0.087 ms
 Execution Time: 1.437 ms
(11 rows)
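
The plan above sorts twice because the two window expressions ask for different orderings (c3 ascending versus c3 descending). A minimal Rust sketch of that idea, grouping window expressions by the sort they require so that each group can share one sort; `SortKey` and `group_by_sort_key` are hypothetical names for illustration, not DataFusion or PostgreSQL APIs:

```rust
use std::collections::BTreeMap;

/// Hypothetical description of the ordering a window expression needs.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SortKey {
    partition_by: Vec<String>,
    /// (column, descending)
    order_by: Vec<(String, bool)>,
}

/// Group window expression names by the sort they require.
/// Each resulting group could be evaluated on top of a single sort.
fn group_by_sort_key(exprs: &[(String, SortKey)]) -> BTreeMap<SortKey, Vec<String>> {
    let mut groups: BTreeMap<SortKey, Vec<String>> = BTreeMap::new();
    for (name, key) in exprs {
        groups.entry(key.clone()).or_default().push(name.clone());
    }
    groups
}
```

With `sum(c3)` ordered ascending and `avg(c3)` ordered descending this yields two groups, matching the two Sort nodes in the plan; a third expression with the ascending ordering would join the first group rather than add another sort.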

Member Author

> I checked whether this makes sense to reuse aggregate functions for window expressions - I think it does! E.g. PostgreSQL also says:
> https://www.postgresql.org/docs/9.3/functions-window.html
>
> > In addition to these functions, any built-in or user-defined aggregate function can be used as a window function (see Section 9.20 for a list of the built-in aggregates). Aggregate functions act as window functions only when an OVER clause follows the call; otherwise they act as regular aggregates.

It is very useful for analytics, e.g. if you want to list (from an employee table with name, department, and salary) the employees in each department with salary above average.
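
A minimal Rust sketch of that use case, assuming "average" means the per-department average, i.e. roughly a filter on `salary > avg(salary) OVER (PARTITION BY department)`. The `Employee` struct and `above_department_average` function are made up for illustration and are not part of this PR:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
struct Employee {
    name: String,
    department: String,
    salary: f64,
}

/// Return the employees whose salary is above the average of their own department.
fn above_department_average(employees: &[Employee]) -> Vec<&Employee> {
    // First pass: accumulate (sum, count) per partition key (the department).
    let mut totals: HashMap<&str, (f64, usize)> = HashMap::new();
    for e in employees {
        let entry = totals.entry(e.department.as_str()).or_insert((0.0, 0));
        entry.0 += e.salary;
        entry.1 += 1;
    }
    // Second pass: compare every row with the aggregate of its own partition.
    // Unlike GROUP BY, no rows are collapsed; each input row is still visible.
    employees
        .iter()
        .filter(|e| {
            let (sum, count) = totals[e.department.as_str()];
            e.salary > sum / count as f64
        })
        .collect()
}
```

The point of the window form is that the aggregate is attached to each row instead of replacing the rows, which is why a plain `GROUP BY` cannot express this in one step.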

Contributor

@Dandandan, May 20, 2021

I believe count(..) over order by .. also needs to be sorted, it will do a count over the window, which means a running count (over sorted rows) by default.
But yeah very useful for analytics indeed 👍
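
To make the running-count point concrete, here is a small Rust sketch of what `count(c3) over (partition by c1 order by c3)` computes under PostgreSQL's default frame (RANGE from the start of the partition to the current row, which includes rows that tie on c3). The function name `running_count` and the tuple layout are assumptions for illustration only:

```rust
/// For each row (c1, c3), return (c1, c3, count), where count is the number of
/// rows in the same c1 partition whose c3 is less than or equal to this row's c3.
fn running_count<'a>(rows: &[(&'a str, i64)]) -> Vec<(&'a str, i64, usize)> {
    let mut sorted: Vec<(&str, i64)> = rows.to_vec();
    // Tuples compare lexicographically, so this sorts by c1 and then by c3.
    sorted.sort();

    let mut out = Vec::with_capacity(sorted.len());
    let mut start = 0; // index of the first row of the current partition
    let mut i = 0;
    while i < sorted.len() {
        if sorted[i].0 != sorted[start].0 {
            start = i; // a new partition begins here
        }
        // Find the end of the peer group: rows with the same (c1, c3).
        let mut j = i;
        while j < sorted.len() && sorted[j] == sorted[i] {
            j += 1;
        }
        // Every peer sees the same frame: partition start up to the last peer.
        let count = j - start;
        for k in i..j {
            out.push((sorted[k].0, sorted[k].1, count));
        }
        i = j;
    }
    out
}
```

Without the sort the counts would depend on input order, so an aggregate used with ORDER BY in the window does need sorted input, just like the built-in window functions.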

Contributor

My comment above was more about reusing the same functions here, as I thought we might not want to support every aggregate function as a window function. But reusing them sounds like a good idea to me. Maybe @alamb has some ideas about it as well.

Contributor

I think SQL is confusing in this area -- as @jimexist says, all "normal" aggregate functions (e.g. sum, count, etc) are also valid window functions, but the reverse is not true. You can't use window functions (e.g. LAG, LEAD, etc) outside of a window clause.

Thus I think representing window functions as a new type of function, as this PR does, makes the most sense. They are different enough (e.g. they require information on the incoming windows) that trying to wrangle them into the same structures as normal aggregates seems like it will get messy. Long term I would expect us to have a UDWF (user defined window function) API as well.

Ideally the physical implementation for sum / count / etc can be mostly reused but in the plans I think they are different enough to warrant different plan structures.
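
The asymmetry described above (every aggregate is a valid window function, but not the reverse) can be sketched with a wrapper enum. The variant names mirror the ones used in this PR's `from_proto.rs`, but the types below are trimmed, standalone stand-ins, and `as_plain_aggregate` is a hypothetical helper rather than an existing DataFusion method:

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum AggregateFunction {
    Min,
    Max,
    Sum,
    Avg,
    Count,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum BuiltInWindowFunction {
    RowNumber,
    Rank,
    Lag,
    Lead,
}

/// A window function is either a reused aggregate or a window-only built-in.
#[derive(Debug, Clone, Copy, PartialEq)]
enum WindowFunction {
    AggregateFunction(AggregateFunction),
    BuiltInWindowFunction(BuiltInWindowFunction),
}

impl WindowFunction {
    /// Only the aggregate variant is also usable without an OVER clause.
    fn as_plain_aggregate(&self) -> Option<AggregateFunction> {
        match self {
            WindowFunction::AggregateFunction(f) => Some(*f),
            WindowFunction::BuiltInWindowFunction(_) => None,
        }
    }
}
```

Keeping the wrapper separate from `AggregateFunction` means the aggregate list is reused as-is, while plan nodes can still tell a window expression from a normal aggregate.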

Member Author

> I believe count(..) over order by .. also needs to be sorted, it will do a count over the window, which means a running count (over sorted rows) by default.

Good point! Indeed.

BuiltInWindowFunction built_in_function = 2;
// udaf = 3
}
LogicalExprNode expr = 4;
}

message BetweenNode {
LogicalExprNode expr = 1;
bool negated = 2;
@@ -200,6 +225,7 @@ message LogicalPlanNode {
EmptyRelationNode empty_relation = 10;
CreateExternalTableNode create_external_table = 11;
ExplainNode explain = 12;
WindowNode window = 13;
}
}

@@ -288,6 +314,50 @@ message AggregateNode {
repeated LogicalExprNode aggr_expr = 3;
}

message WindowNode {
  LogicalPlanNode input = 1;
  repeated LogicalExprNode window_expr = 2;
  repeated LogicalExprNode partition_by_expr = 3;
  repeated LogicalExprNode order_by_expr = 4;
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof window_frame {
    WindowFrame frame = 5;
  }
  // TODO add filter by expr
}

enum WindowFrameUnits {
  ROWS = 0;
  RANGE = 1;
  GROUPS = 2;
}

message WindowFrame {
  WindowFrameUnits window_frame_units = 1;
  WindowFrameBound start_bound = 2;
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof end_bound {
    WindowFrameBound bound = 3;
  }
}

enum WindowFrameBoundType {
  CURRENT_ROW = 0;
  PRECEDING = 1;
  FOLLOWING = 2;
}

message WindowFrameBound {
  WindowFrameBoundType window_frame_bound_type = 1;
  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
  oneof bound_value {
    uint64 value = 2;
  }
}

enum JoinType {
INNER = 0;
LEFT = 1;
@@ -334,6 +404,7 @@ message PhysicalPlanNode {
MergeExecNode merge = 14;
UnresolvedShuffleExecNode unresolved = 15;
RepartitionExecNode repartition = 16;
WindowAggExecNode window = 17;
}
}

@@ -399,6 +470,13 @@ enum AggregateMode {
FINAL_PARTITIONED = 2;
}

message WindowAggExecNode {
  PhysicalPlanNode input = 1;
  repeated LogicalExprNode window_expr = 2;
  repeated string window_expr_name = 3;
  Schema input_schema = 4;
}

message HashAggregateExecNode {
repeated LogicalExprNode group_expr = 1;
repeated LogicalExprNode aggr_expr = 2;
197 changes: 186 additions & 11 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -17,22 +17,23 @@

//! Serde code to convert from protocol buffers to Rust data structures.

use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::{
convert::{From, TryInto},
unimplemented,
};

use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
use datafusion::scalar::ScalarValue;
use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
@@ -75,6 +76,34 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.build()
.map_err(|e| e.into())
}
LogicalPlanType::Window(window) => {
let input: LogicalPlan = convert_box_required!(window.input)?;
let window_expr = window
.window_expr
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;

// let partition_by_expr = window
// .partition_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// let order_by_expr = window
// .order_by_expr
// .iter()
// .map(|expr| expr.try_into())
// .collect::<Result<Vec<_>, _>>()?;
// // FIXME: add filter by expr
// // FIXME: parse the window_frame data
// let window_frame = None;
LogicalPlanBuilder::from(&input)
.window(
window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/
)?
.build()
.map_err(|e| e.into())
}
LogicalPlanType::Aggregate(aggregate) => {
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
let group_expr = aggregate
@@ -871,7 +900,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
type Error = BallistaError;

fn try_into(self) -> Result<Expr, Self::Error> {
use datafusion::physical_plan::window_functions;
use protobuf::logical_expr_node::ExprType;
use protobuf::window_expr_node;
use protobuf::WindowExprNode;

let expr_type = self
.expr_type
@@ -889,6 +921,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
Ok(Expr::Literal(scalar_value))
}
ExprType::WindowExpr(expr) => {
let window_function = expr
.window_function
.as_ref()
.ok_or_else(|| proto_error("Received empty window function"))?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
.ok_or_else(|| {
proto_error(format!(
"Received an unknown aggregate window function: {}",
i
))
})?;

Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::AggregateFunction(
AggregateFunction::from(aggr_function),
),
args: vec![parse_required_expr(&expr.expr)?],
})
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
let built_in_function =
protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
|| {
proto_error(format!(
"Received an unknown built-in window function: {}",
i
))
},
)?;

Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
BuiltInWindowFunction::from(built_in_function),
),
args: vec![parse_required_expr(&expr.expr)?],
})
}
}
}
ExprType::AggregateExpr(expr) => {
let aggr_function =
protobuf::AggregateFunction::from_i32(expr.aggr_function)
@@ -898,13 +972,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
expr.aggr_function
))
})?;
let fun = match aggr_function {
protobuf::AggregateFunction::Min => AggregateFunction::Min,
protobuf::AggregateFunction::Max => AggregateFunction::Max,
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
};
let fun = AggregateFunction::from(aggr_function);

Ok(Expr::AggregateFunction {
fun,
@@ -1152,6 +1220,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
}

use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
use datafusion::physical_plan::{aggregates, windows};
use datafusion::prelude::{
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
};
@@ -1202,3 +1271,109 @@ fn parse_optional_expr(
None => Ok(None),
}
}

impl From<protobuf::WindowFrameUnits> for WindowFrameUnits {
    fn from(units: protobuf::WindowFrameUnits) -> Self {
        match units {
            protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
            protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range,
            protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
        }
    }
}

impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
    type Error = BallistaError;

    fn try_from(bound: protobuf::WindowFrameBound) -> Result<Self, Self::Error> {
        let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| {
            proto_error(format!(
                "Received a WindowFrameBound message with unknown WindowFrameBoundType {}",
                bound.window_frame_bound_type
            ))
        })?;
        match bound_type {
            protobuf::WindowFrameBoundType::CurrentRow => {
                Ok(WindowFrameBound::CurrentRow)
            }
            protobuf::WindowFrameBoundType::Preceding => {
                // FIXME implement bound value parsing
                Ok(WindowFrameBound::Preceding(Some(1)))
            }
            protobuf::WindowFrameBoundType::Following => {
                // FIXME implement bound value parsing
                Ok(WindowFrameBound::Following(Some(1)))
            }
        }
    }
}

impl TryFrom<protobuf::WindowFrame> for WindowFrame {
    type Error = BallistaError;

    fn try_from(window: protobuf::WindowFrame) -> Result<Self, Self::Error> {
        let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units)
            .ok_or_else(|| {
                proto_error(format!(
                    "Received a WindowFrame message with unknown WindowFrameUnits {}",
                    window.window_frame_units
                ))
            })?
            .into();
        let start_bound = window
            .start_bound
            .ok_or_else(|| {
                proto_error(
                    "Received a WindowFrame message with no start_bound".to_owned(),
                )
            })?
            .try_into()?;
        // FIXME parse end bound
        let end_bound = None;
        Ok(WindowFrame {
            units,
            start_bound,
            end_bound,
        })
    }
}

impl From<protobuf::AggregateFunction> for AggregateFunction {
    fn from(aggr_function: protobuf::AggregateFunction) -> Self {
        match aggr_function {
            protobuf::AggregateFunction::Min => AggregateFunction::Min,
            protobuf::AggregateFunction::Max => AggregateFunction::Max,
            protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
            protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
            protobuf::AggregateFunction::Count => AggregateFunction::Count,
        }
    }
}

impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
    fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
        match built_in_function {
            protobuf::BuiltInWindowFunction::RowNumber => {
                BuiltInWindowFunction::RowNumber
            }
            protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
            protobuf::BuiltInWindowFunction::PercentRank => {
                BuiltInWindowFunction::PercentRank
            }
            protobuf::BuiltInWindowFunction::DenseRank => {
                BuiltInWindowFunction::DenseRank
            }
            protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
            protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
            protobuf::BuiltInWindowFunction::FirstValue => {
                BuiltInWindowFunction::FirstValue
            }
            protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
            protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
            protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
            protobuf::BuiltInWindowFunction::LastValue => {
                BuiltInWindowFunction::LastValue
            }
        }
    }
}
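
Note on the two "FIXME implement bound value parsing" markers in the `WindowFrameBound` conversion: they currently hard-code `Some(1)`. A rough sketch of how the optional bound value could be carried through instead, using simplified stand-in types rather than the prost-generated or sqlparser ones. `ProtoBound`, `ProtoBoundType`, `FrameBound`, and `parse_bound` are hypothetical names, the `oneof bound_value` is modelled as a plain `Option<u64>`, and treating a missing value as UNBOUNDED is an assumption:

```rust
/// Stand-in for the protobuf WindowFrameBoundType enum.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ProtoBoundType {
    CurrentRow,
    Preceding,
    Following,
}

/// Stand-in for the protobuf WindowFrameBound message.
/// `value` models the `oneof bound_value` as an optional u64.
struct ProtoBound {
    bound_type: ProtoBoundType,
    value: Option<u64>,
}

/// Stand-in for the target frame bound type.
/// Assumption: `None` means UNBOUNDED PRECEDING / UNBOUNDED FOLLOWING.
#[derive(Debug, PartialEq)]
enum FrameBound {
    CurrentRow,
    Preceding(Option<u64>),
    Following(Option<u64>),
}

/// Convert a decoded bound, passing the optional value through
/// instead of substituting a constant.
fn parse_bound(bound: &ProtoBound) -> Result<FrameBound, String> {
    match (bound.bound_type, bound.value) {
        (ProtoBoundType::CurrentRow, None) => Ok(FrameBound::CurrentRow),
        (ProtoBoundType::CurrentRow, Some(v)) => Err(format!(
            "CURRENT ROW bound must not carry a value, got {}",
            v
        )),
        (ProtoBoundType::Preceding, v) => Ok(FrameBound::Preceding(v)),
        (ProtoBoundType::Following, v) => Ok(FrameBound::Following(v)),
    }
}
```

Rejecting a value on CURRENT ROW is a design choice in this sketch; silently ignoring it would also be defensible.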