2 changes: 1 addition & 1 deletion datafusion/Cargo.toml
@@ -48,7 +48,7 @@ ahash = "0.7"
 hashbrown = "0.11"
 arrow = { git = "https://github.com/cube-js/arrow-rs.git", branch = "cube", features = ["prettyprint"] }
 parquet = { git = "https://github.com/cube-js/arrow-rs.git", branch = "cube", features = ["arrow"] }
-sqlparser = { git = "https://github.com/cube-js/sqlparser-rs.git", rev = "2fcd06f7354e8c85f170b49a08fc018749289a40" }
+sqlparser = { git = "https://github.com/cube-js/sqlparser-rs.git", rev = "b1d144a2cb5cc47ac950fd1d518bc28b4dc33ab9" }
 paste = "^1.0"
 num_cpus = "1.13.0"
 chrono = "0.4"
107 changes: 59 additions & 48 deletions datafusion/src/cube_ext/rolling.rs
@@ -45,6 +45,7 @@ use async_trait::async_trait;
 use chrono::{TimeZone, Utc};
 use hashbrown::HashMap;
 use itertools::Itertools;
+use sqlparser::ast::RollingOffset;
 use std::any::Any;
 use std::cmp::{max, Ordering};
 use std::convert::TryFrom;
@@ -226,7 +227,12 @@ impl ExtensionPlanner for Planner {
 .iter()
 .map(|e| -> Result<_, DataFusionError> {
 match e {
-Expr::RollingAggregate { agg, start, end } => {
+Expr::RollingAggregate {
+agg,
+start,
+end,
+offset,
+} => {
 let start = frame_bound_to_diff(start, dimension_type)?;
 let end = frame_bound_to_diff(end, dimension_type)?;
 let agg = planner.create_aggregate_expr(
@@ -239,6 +245,10 @@
 agg,
 lower_bound: start,
 upper_bound: end,
+offset_to_end: match offset {
+RollingOffset::Start => false,
+RollingOffset::End => true,
+},
 })
 }
 _ => panic!("expected ROLLING() aggregate, got {:?}", e),
@@ -341,6 +351,8 @@ pub struct RollingAgg {
 /// The bound is inclusive.
 pub upper_bound: Option<ScalarValue>,
 pub agg: Arc<dyn AggregateExpr>,
+/// When true, all calculations must be done for the last point in the interval.
+pub offset_to_end: bool,
 }
 
 #[derive(Debug)]
@@ -450,7 +462,12 @@ impl ExecutionPlan for RollingWindowAggExec {
 .group_by_dimension
 .as_ref()
 .map(|d| -> Result<_, DataFusionError> {
-Ok(d.evaluate(&input)?.into_array(num_rows))
+let mut d = d.evaluate(&input)?.into_array(num_rows);
+if d.data_type() != &dim_iter_type {
+// This is to upcast timestamps to nanosecond precision.
+d = arrow::compute::cast(&d, &dim_iter_type)?;
+}
+Ok(d)
 })
 .transpose()?;
 let extra_aggs_inputs = self
@@ -495,6 +512,11 @@ impl ExecutionPlan for RollingWindowAggExec {
 // Avoid running indefinitely due to all kinds of errors.
 let mut window_start = group_start;
 let mut window_end = group_start;
+let offset_to_end = if r.offset_to_end {
+Some(&self.every)
+} else {
+None
+};
 
 let mut d = self.from.clone();
 let mut d_iter = 0;
@@ -505,6 +527,7 @@
 .unwrap(),
 &d,
 r.lower_bound.as_ref(),
+offset_to_end,
 )
 {
 window_start += 1;
@@ -515,6 +538,7 @@
 &ScalarValue::try_from_array(&dimension, window_end).unwrap(),
 &d,
 r.upper_bound.as_ref(),
+offset_to_end,
 )
 {
 window_end += 1;
@@ -747,10 +771,38 @@ fn compute_agg_inputs(
 .collect()
 }
 
+/// Returns `(value, current+bounds)` pair that can be used for comparison to check window bounds.
+fn prepare_bound_compare(
+value: &ScalarValue,
+current: &ScalarValue,
+bound: &ScalarValue,
+offset_to_end: Option<&ScalarValue>,
+) -> (i64, i64) {
+let mut added = add_dim(current, bound);
+if let Some(offset) = offset_to_end {
+added = add_dim(&added, offset)
+}
+
+let (mut added, value) = match (added, value) {
+(ScalarValue::Int64(Some(a)), ScalarValue::Int64(Some(v))) => (a, v),
+(
+ScalarValue::TimestampNanosecond(Some(a)),
+ScalarValue::TimestampNanosecond(Some(v)),
+) => (a, v),
+(a, v) => panic!("unsupported values in rolling window: ({:?}, {:?})", a, v),
+};
+
+if offset_to_end.is_some() {
+added -= 1
+}
+(*value, added)
+}
+
 fn meets_lower_bound(
 value: &ScalarValue,
 current: &ScalarValue,
 bound: Option<&ScalarValue>,
+offset_to_end: Option<&ScalarValue>,
 ) -> bool {
 let bound = match bound {
 Some(p) => p,
@@ -761,35 +813,15 @@ fn meets_lower_bound(
 if value.is_null() {
 return false;
 }
-match (current, value, bound) {
-(
-ScalarValue::Int64(Some(current)),
-ScalarValue::Int64(Some(value)),
-ScalarValue::Int64(Some(bound)),
-) => return current + *bound <= *value,
-(
-ScalarValue::TimestampNanosecond(Some(_)),
-ScalarValue::TimestampNanosecond(Some(value)),
-ScalarValue::IntervalYearMonth(Some(_))
-| ScalarValue::IntervalDayTime(Some(_)),
-) => {
-let added = match add_dim(current, bound) {
-ScalarValue::TimestampNanosecond(Some(v)) => v,
-o => panic!("expected timestamp, got {}", o),
-};
-return added <= *value;
-}
-_ => panic!(
-"unsupported values in rolling window: ({:?}, {:?}, {:?})",
-current, value, bound
-),
-}
+let (value, added) = prepare_bound_compare(value, current, bound, offset_to_end);
+added <= value
 }
 
 fn meets_upper_bound(
 value: &ScalarValue,
 current: &ScalarValue,
 bound: Option<&ScalarValue>,
+offset_to_end: Option<&ScalarValue>,
 ) -> bool {
 let bound = match bound {
 Some(p) => p,
@@ -800,29 +832,8 @@
 if value.is_null() {
 return false;
 }
-match (current, value, bound) {
-(
-ScalarValue::Int64(Some(current)),
-ScalarValue::Int64(Some(value)),
-ScalarValue::Int64(Some(bound)),
-) => return *value <= current + *bound,
-(
-ScalarValue::TimestampNanosecond(Some(_)),
-ScalarValue::TimestampNanosecond(Some(value)),
-ScalarValue::IntervalYearMonth(Some(_))
-| ScalarValue::IntervalDayTime(Some(_)),
-) => {
-let added = match add_dim(current, bound) {
-ScalarValue::TimestampNanosecond(Some(v)) => v,
-o => panic!("expected timestamp, got {}", o),
-};
-return *value <= added;
-}
-_ => panic!(
-"unsupported values in rolling window: ({:?}, {:?}, {:?})",
-current, value, bound
-),
-}
+let (value, added) = prepare_bound_compare(value, current, bound, offset_to_end);
+value <= added
 }
 
 fn expect_non_null_scalar(
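For orientation, here is a minimal standalone sketch of the bound check that prepare_bound_compare implements, using plain i64 values instead of DataFusion's ScalarValue. The function and variable names are illustrative only, and the `every` argument stands in for the RollingWindowAggExec step that is passed as offset_to_end above.

// Sketch only: mirrors the arithmetic of prepare_bound_compare for the Int64
// case. With OFFSET END the window is evaluated for the last point of the
// interval, so the bound is shifted by `every` and pulled back by one to stay
// inclusive.
fn window_bound(current: i64, bound: i64, offset_to_end: Option<i64>) -> i64 {
    match offset_to_end {
        Some(every) => current + bound + every - 1,
        None => current + bound,
    }
}

fn meets_lower(value: i64, current: i64, bound: i64, offset_to_end: Option<i64>) -> bool {
    window_bound(current, bound, offset_to_end) <= value
}

fn meets_upper(value: i64, current: i64, bound: i64, offset_to_end: Option<i64>) -> bool {
    value <= window_bound(current, bound, offset_to_end)
}

fn main() {
    // With every = 10 and a bound of 0, the point 9 is still inside the window
    // anchored at 0 under OFFSET END (0 + 0 + 10 - 1 = 9), but not under the
    // default OFFSET START.
    assert!(meets_upper(9, 0, 0, Some(10)));
    assert!(!meets_upper(9, 0, 0, None));
    assert!(meets_lower(9, 0, 0, Some(10)));
}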
19 changes: 16 additions & 3 deletions datafusion/src/logical_plan/expr.rs
@@ -31,6 +31,7 @@ use arrow::{compute::can_cast_types, datatypes::DataType};
 use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
 use serde_derive::Deserialize;
 use serde_derive::Serialize;
+use sqlparser::ast::RollingOffset;
 use std::collections::{HashMap, HashSet};
 use std::convert::Infallible;
 use std::fmt;
@@ -357,6 +358,8 @@ pub enum Expr {
 start: window_frames::WindowFrameBound,
 /// End
 end: window_frames::WindowFrameBound,
+/// Offset
+offset: RollingOffset,
 },
 /// Returns whether the list contains the expr value.
 InList {
@@ -935,10 +938,12 @@ impl Expr {
 agg,
 start: start_bound,
 end: end_bound,
+offset,
 } => Expr::RollingAggregate {
 agg: rewrite_boxed(agg, rewriter)?,
 start: start_bound,
 end: end_bound,
+offset,
 },
 Expr::Wildcard => Expr::Wildcard,
 };
@@ -1669,9 +1674,11 @@ impl fmt::Debug for Expr {
 agg,
 start: start_bound,
 end: end_bound,
+offset,
 } => {
 write!(f, "ROLLING({:?} RANGE", agg)?;
 write!(f, " BETWEEN {} AND {}", start_bound, end_bound)?;
+write!(f, " OFFSET {}", offset)?;
 write!(f, ")")
 }
 Expr::InList {
@@ -1805,11 +1812,17 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
 }
 Ok(format!("{}({})", fun.name, names.join(",")))
 }
-Expr::RollingAggregate { agg, start, end } => Ok(format!(
-"ROLLING({} RANGE BETWEEN {} AND {})",
+Expr::RollingAggregate {
+agg,
+start,
+end,
+offset,
+} => Ok(format!(
+"ROLLING({} RANGE BETWEEN {} AND {} OFFSET {})",
 create_name(agg, input_schema)?,
 start,
-end
+end,
+offset,
 )),
 Expr::InList {
 expr,
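A small self-contained sketch of the widened name format produced by create_name and the Debug impl above; the aggregate and bound strings below are placeholders, not actual DataFusion display output. Including the offset in the generated name keeps two rolling aggregates that differ only in their OFFSET from colliding on the same output column name.

// Sketch: rebuilds the format string used by the widened create_name arm.
// The argument values in main are illustrative placeholders.
fn rolling_name(agg: &str, start: &str, end: &str, offset: &str) -> String {
    format!(
        "ROLLING({} RANGE BETWEEN {} AND {} OFFSET {})",
        agg, start, end, offset
    )
}

fn main() {
    let a = rolling_name("SUM(orders.count)", "UNBOUNDED PRECEDING", "CURRENT ROW", "START");
    let b = rolling_name("SUM(orders.count)", "UNBOUNDED PRECEDING", "CURRENT ROW", "END");
    // Different offsets now yield different expression names.
    assert_ne!(a, b);
    println!("{}\n{}", a, b);
}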
4 changes: 3 additions & 1 deletion datafusion/src/optimizer/utils.rs
@@ -443,13 +443,15 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
 }
 Expr::InList { .. } => Ok(expr.clone()),
 Expr::RollingAggregate {
+agg: _,
 start: start_bound,
 end: end_bound,
-..
+offset,
 } => Ok(Expr::RollingAggregate {
 agg: Box::new(expressions[0].clone()),
 start: start_bound.clone(),
 end: end_bound.clone(),
+offset: *offset,
 }),
 Expr::Wildcard { .. } => Err(DataFusionError::Internal(
 "Wildcard expressions are not valid in a logical query plan".to_owned(),
8 changes: 5 additions & 3 deletions datafusion/src/sql/planner.rs
@@ -48,9 +48,9 @@ use arrow::datatypes::*;
 use hashbrown::HashMap;
 use sqlparser::ast::{
 BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
-Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset, Query, Select,
-SelectItem, SetExpr, SetOperator, ShowStatementFilter, TableFactor, TableWithJoins,
-UnaryOperator, Value,
+Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset, Query, RollingOffset,
+Select, SelectItem, SetExpr, SetOperator, ShowStatementFilter, TableFactor,
+TableWithJoins, UnaryOperator, Value,
 };
 use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
 use sqlparser::ast::{OrderByExpr, Statement};
@@ -1482,6 +1482,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 agg,
 first_bound,
 second_bound,
+offset,
 } => {
 let agg = match self.sql_expr_to_logical_expr(&agg, schema)? {
 e @ Expr::AggregateFunction { .. }
@@ -1506,6 +1507,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 agg: Box::new(agg),
 start,
 end,
+offset: (*offset).unwrap_or(RollingOffset::Start),
 })
 }
 
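As a reminder of the defaulting rule introduced in the planner arm above, an omitted OFFSET clause falls back to START. A hedged standalone sketch follows; the local enum only mirrors sqlparser's RollingOffset for illustration.

// Stand-in for sqlparser::ast::RollingOffset, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
enum RollingOffset {
    Start,
    End,
}

// Mirrors `(*offset).unwrap_or(RollingOffset::Start)` in the planner code above.
fn resolve_offset(parsed: Option<RollingOffset>) -> RollingOffset {
    parsed.unwrap_or(RollingOffset::Start)
}

fn main() {
    assert_eq!(resolve_offset(None), RollingOffset::Start);
    assert_eq!(resolve_offset(Some(RollingOffset::End)), RollingOffset::End);
}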
2 changes: 2 additions & 0 deletions datafusion/src/sql/utils.rs
@@ -388,10 +388,12 @@ where
 agg,
 start: start_bound,
 end: end_bound,
+offset,
 } => Ok(Expr::RollingAggregate {
 agg: Box::new(clone_with_replacement(agg, replacement_fn)?),
 start: start_bound.clone(),
 end: end_bound.clone(),
+offset: *offset,
 }),
 Expr::Wildcard => Ok(Expr::Wildcard),
 },