feat: Initial support for Window function (#599)
* feat: initial support for Window function

Co-authored-by: comphead <comphead@ukr.net>

* fix style

* fix style

* address comments

* abs()->unsigned_abs()

* address comments

---------

Co-authored-by: comphead <comphead@ukr.net>
huaxingao and comphead authored Jul 2, 2024
1 parent 917fd43 commit 0d2fcbc
Showing 6 changed files with 589 additions and 2 deletions.
196 changes: 195 additions & 1 deletion core/src/execution/datafusion/planner.rs
@@ -20,6 +20,8 @@
use std::{collections::HashMap, sync::Arc};

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
@@ -50,12 +52,17 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
};
use datafusion_expr::ScalarUDF;
use datafusion_expr::expr::find_df_window_func;
use datafusion_expr::{ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr::window::WindowExpr;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};

use crate::execution::spark_operator::lower_window_frame_bound::LowerFrameBoundStruct;
use crate::execution::spark_operator::upper_window_frame_bound::UpperFrameBoundStruct;
use crate::execution::spark_operator::WindowFrameType;
use crate::{
errors::ExpressionError,
execution::{
@@ -980,6 +987,47 @@ impl PhysicalPlanner {

Ok((scans, hash_join))
}
OpStruct::Window(wnd) => {
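// Convert Spark's Window operator: build sort and partition expressions
// against the child's schema, translate each window expression, and wrap
// the child in a BoundedWindowAggExec over sorted input.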
let (scans, child) = self.create_plan(&children[0], inputs)?;
let input_schema = child.schema();
let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
.order_by_list
.iter()
.map(|expr| self.create_sort_expr(expr, input_schema.clone()))
.collect();

let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
.partition_by_list
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect();

let sort_exprs = &sort_exprs?;
let partition_exprs = &partition_exprs?;

let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
.window_expr
.iter()
.map(|expr| {
self.create_window_expr(
expr,
input_schema.clone(),
partition_exprs,
sort_exprs,
)
})
.collect();

Ok((
scans,
Arc::new(BoundedWindowAggExec::try_new(
window_expr?,
child,
partition_exprs.to_vec(),
InputOrderMode::Sorted,
)?),
))
}
}
}
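
// Editor's note: the three map/collect chains above rely on collecting an
// iterator of Result values into Result<Vec<_>, _>, which short-circuits on
// the first error. A minimal standalone sketch of that idiom (illustration
// only, not part of this commit):
fn parse_all(inputs: &[&str]) -> Result<Vec<i64>, std::num::ParseIntError> {
    inputs.iter().map(|s| s.parse::<i64>()).collect()
}

fn demo_parse_all() {
    assert_eq!(parse_all(&["1", "2"]), Ok(vec![1, 2]));
    assert!(parse_all(&["1", "x"]).is_err()); // the first failure aborts the collect
}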

@@ -1322,6 +1370,152 @@ impl PhysicalPlanner {
}
}

/// Create a DataFusion window physical expression from a Spark physical expression
fn create_window_expr<'a>(
&'a self,
spark_expr: &'a crate::execution::spark_operator::WindowExpr,
input_schema: SchemaRef,
partition_by: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[PhysicalSortExpr],
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new());
if let Some(func) = &spark_expr.built_in_window_function {
match &func.expr_struct {
Some(ExprStruct::ScalarFunc(f)) => {
window_func_name.clone_from(&f.func);
window_func_args.clone_from(&f.args);
}
other => {
return Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
)))
}
};
} else if let Some(agg_func) = &spark_expr.agg_func {
let result = Self::process_agg_func(agg_func)?;
window_func_name = result.0;
window_func_args = result.1;
} else {
return Err(ExecutionError::GeneralError(
"Both func and agg_func are not set".to_string(),
));
}

let window_func = match find_df_window_func(&window_func_name) {
Some(f) => f,
_ => {
return Err(ExecutionError::GeneralError(format!(
"{window_func_name} not supported for window function"
)))
}
};

let window_args = window_func_args
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, ExecutionError>>()?;

let spark_window_frame = match spark_expr
.spec
.as_ref()
.and_then(|inner| inner.frame_specification.as_ref())
{
Some(frame) => frame,
_ => {
return Err(ExecutionError::DeserializeError(
"Cannot deserialize window frame".to_string(),
))
}
};

let units = match spark_window_frame.frame_type() {
WindowFrameType::Rows => WindowFrameUnits::Rows,
WindowFrameType::Range => WindowFrameUnits::Range,
};

let lower_bound: WindowFrameBound = match spark_window_frame
.lower_bound
.as_ref()
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
LowerFrameBoundStruct::Preceding(offset) => {
let offset_value = offset.offset.unsigned_abs() as u64;
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
};

let upper_bound: WindowFrameBound = match spark_window_frame
.upper_bound
.as_ref()
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
UpperFrameBoundStruct::Following(offset) => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);

datafusion::physical_plan::windows::create_window_expr(
&window_func,
window_func_name,
&window_args,
partition_by,
sort_exprs,
window_frame.into(),
&input_schema,
false, // TODO: Ignore nulls
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}
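
// Editor's illustration (not part of the commit): a minimal sketch of the
// frame translation above, using the same datafusion_expr and
// datafusion_common APIs this file imports. Spark's
// ROWS BETWEEN 2 PRECEDING AND CURRENT ROW becomes the frame below;
// ScalarValue::UInt64(None) in a bound would mean UNBOUNDED.
fn rows_two_preceding_to_current_row() -> WindowFrame {
    WindowFrame::new_bounds(
        WindowFrameUnits::Rows,
        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
        WindowFrameBound::CurrentRow,
    )
}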

fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec<Expr>), ExecutionError> {
fn optional_expr_to_vec(expr_option: &Option<Expr>) -> Vec<Expr> {
expr_option
.as_ref()
.cloned()
.map_or_else(Vec::new, |e| vec![e])
}

fn int_to_stats_type(value: i32) -> Option<StatsType> {
match value {
0 => Some(StatsType::Sample),
1 => Some(StatsType::Population),
_ => None,
}
}

match &agg_func.expr_struct {
Some(AggExprStruct::Count(expr)) => {
let args = &expr.children;
Ok(("count".to_string(), args.to_vec()))
}
Some(AggExprStruct::Min(expr)) => {
Ok(("min".to_string(), optional_expr_to_vec(&expr.child)))
}
Some(AggExprStruct::Max(expr)) => {
Ok(("max".to_string(), optional_expr_to_vec(&expr.child)))
}
other => Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
))),
}
}
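
// Editor's illustration (not part of the commit): a generic standalone sketch
// of the optional_expr_to_vec helper above. An absent child becomes an empty
// argument list; a present child becomes a single-element list.
fn optional_to_vec<T: Clone>(opt: &Option<T>) -> Vec<T> {
    opt.as_ref().cloned().map_or_else(Vec::new, |e| vec![e])
}

fn demo_optional_to_vec() {
    assert_eq!(optional_to_vec::<i32>(&None), Vec::<i32>::new());
    assert_eq!(optional_to_vec(&Some(7)), vec![7]);
}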

/// Create a DataFusion physical partitioning from Spark physical partitioning
fn create_partitioning(
&self,
59 changes: 59 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
HashJoin hash_join = 109;
Window window = 110;
}
}

@@ -120,3 +121,61 @@ enum BuildSide {
BuildLeft = 0;
BuildRight = 1;
}

message WindowExpr {
spark.spark_expression.Expr built_in_window_function = 1;
spark.spark_expression.AggExpr agg_func = 2;
WindowSpecDefinition spec = 3;
}

enum WindowFrameType {
Rows = 0;
Range = 1;
}

message WindowFrame {
WindowFrameType frame_type = 1;
LowerWindowFrameBound lower_bound = 2;
UpperWindowFrameBound upper_bound = 3;
}

message LowerWindowFrameBound {
oneof lower_frame_bound_struct {
UnboundedPreceding unboundedPreceding = 1;
Preceding preceding = 2;
CurrentRow currentRow = 3;
}
}

message UpperWindowFrameBound {
oneof upper_frame_bound_struct {
UnboundedFollowing unboundedFollowing = 1;
Following following = 2;
CurrentRow currentRow = 3;
}
}

message Preceding {
int32 offset = 1;
}

message Following {
int32 offset = 1;
}

message UnboundedPreceding {}
message UnboundedFollowing {}
message CurrentRow {}

message WindowSpecDefinition {
repeated spark.spark_expression.Expr partitionSpec = 1;
repeated spark.spark_expression.Expr orderSpec = 2;
WindowFrame frameSpecification = 3;
}

message Window {
repeated WindowExpr window_expr = 1;
repeated spark.spark_expression.Expr order_by_list = 2;
repeated spark.spark_expression.Expr partition_by_list = 3;
Operator child = 4;
}
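
A hedged sketch of how these messages surface on the Rust side, assuming prost-generated types (the module and field names follow prost's conventions and match the imports in planner.rs above; the exact generated signatures are not verified here):

use crate::execution::spark_operator::{
    lower_window_frame_bound::LowerFrameBoundStruct, LowerWindowFrameBound, Preceding,
};

// Builds the lower bound for "2 PRECEDING".
fn two_preceding() -> LowerWindowFrameBound {
    LowerWindowFrameBound {
        lower_frame_bound_struct: Some(LowerFrameBoundStruct::Preceding(Preceding { offset: 2 })),
    }
}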
CometSparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -541,6 +542,17 @@ class CometSparkSessionExtensions
withInfo(s, Seq(info1, info2).flatten.mkString(","))
s

case w: WindowExec =>
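// If transform1 can convert the window plan to a native operator, wrap it
// in CometWindowExec behind a CometSinkPlaceHolder; otherwise keep Spark's
// WindowExec unchanged.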
val newOp = transform1(w)
newOp match {
case Some(nativeOp) =>
val cometOp =
CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
}

case u: UnionExec
if isCometOperatorEnabled(conf, "union") &&
u.children.forall(isCometNative) =>
