Skip to content

Commit

Permalink
Window frame GROUPS mode support (#4155)
Browse files Browse the repository at this point in the history
* Implementation of GROUPS mode in window frame

* Break down/flatten some functions, move comments inline

* Change find_next_group_and_start_index to use an exponentially growing search algorithm

* Removed the TODO after verification of correct implementation

* Window frame state capturing the state of the window frame calculations

* Do not use unnecessary traits and structs for window frame states

* Refactor to avoid carrying the window frame object around

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
  • Loading branch information
zembunia and ozankabak authored Nov 11, 2022
1 parent 225d62c commit 129654c
Show file tree
Hide file tree
Showing 10 changed files with 1,266 additions and 238 deletions.
29 changes: 24 additions & 5 deletions datafusion/common/src/bisect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,41 @@ pub fn bisect<const SIDE: bool>(
target: &[ScalarValue],
sort_options: &[SortOptions],
) -> Result<usize> {
let mut low: usize = 0;
let mut high: usize = item_columns
let low: usize = 0;
let high: usize = item_columns
.get(0)
.ok_or_else(|| {
DataFusionError::Internal("Column array shouldn't be empty".to_string())
})?
.len();
let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
let cmp = compare(current, target, sort_options)?;
Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
};
find_bisect_point(item_columns, target, compare_fn, low, high)
}

/// This function searches for a tuple of target values among the given rows using the bisection algorithm.
/// At each step, the boolean-valued function `compare_fn` compares the current row with the target and
/// decides which half to keep: a return value of `false` continues the search on the left, while `true`
/// continues it on the right.
pub fn find_bisect_point<F>(
item_columns: &[ArrayRef],
target: &[ScalarValue],
compare_fn: F,
mut low: usize,
mut high: usize,
) -> Result<usize>
where
F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>,
{
while low < high {
let mid = ((high - low) / 2) + low;
let val = item_columns
.iter()
.map(|arr| ScalarValue::try_from_array(arr, mid))
.collect::<Result<Vec<ScalarValue>>>()?;
let cmp = compare(&val, target, sort_options)?;
let flag = if SIDE { cmp.is_lt() } else { cmp.is_le() };
if flag {
if compare_fn(&val, target)? {
low = mid + 1;
} else {
high = mid;
Expand Down
8 changes: 1 addition & 7 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
Expand Down Expand Up @@ -1457,12 +1457,6 @@ pub fn create_window_expr_with_name(
})
.collect::<Result<Vec<_>>>()?;
if let Some(ref window_frame) = window_frame {
if window_frame.units == WindowFrameUnits::Groups {
return Err(DataFusionError::NotImplemented(
"Window frame definitions involving GROUPS are not supported yet"
.to_string(),
));
}
if !is_window_valid(window_frame) {
return Err(DataFusionError::Execution(format!(
"Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
Expand Down
167 changes: 157 additions & 10 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,24 +1189,171 @@ async fn window_frame_ranges_unbounded_preceding_err() -> Result<()> {
}

#[tokio::test]
async fn window_frame_groups_query() -> Result<()> {
async fn window_frame_groups_preceding_following_desc() -> Result<()> {
// Exercises GROUPS frames that combine PRECEDING and FOLLOWING offsets over a
// descending ORDER BY key (c2), using both a narrow extent (1 group each way)
// and a very wide one (10000 groups each way).
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT
SUM(c4) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING),
SUM(c3) OVER(ORDER BY c2 DESC GROUPS BETWEEN 10000 PRECEDING AND 10000 FOLLOWING),
COUNT(*) OVER(ORDER BY c2 DESC GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100
ORDER BY c9
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
// The SUM(c3) column is expected to be the same value (781) on every row,
// presumably because the 10000-group frame spans the whole table (TODO confirm);
// the two 1-group frames produce values that differ from row to row.
let expected = vec![
"+----------------------------+----------------------------+-----------------+",
"| SUM(aggregate_test_100.c4) | SUM(aggregate_test_100.c3) | COUNT(UInt8(1)) |",
"+----------------------------+----------------------------+-----------------+",
"| 52276 | 781 | 56 |",
"| 260620 | 781 | 63 |",
"| -28623 | 781 | 37 |",
"| 260620 | 781 | 63 |",
"| 260620 | 781 | 63 |",
"+----------------------------+----------------------------+-----------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

// Checks a GROUPS frame (5 PRECEDING to 3 FOLLOWING) ordered by c1 DESC on the
// `null_cases` table. Judging by the test name, c1 is expected to contain NULLs
// (presumably supplied by the registered CSV; verify against the fixture).
#[tokio::test]
async fn window_frame_groups_order_by_null_desc() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
COUNT(c2) OVER (ORDER BY c1 DESC GROUPS BETWEEN 5 PRECEDING AND 3 FOLLOWING)
FROM null_cases
LIMIT 5";
let actual = execute_to_batches(&ctx, sql).await;
// Each of the first five output rows is expected to report COUNT(c2) = 12.
let expected = vec![
"+----------------------+",
"| COUNT(null_cases.c2) |",
"+----------------------+",
"| 12 |",
"| 12 |",
"| 12 |",
"| 12 |",
"| 12 |",
"+----------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

// Broad coverage of GROUPS frames with a single ORDER BY column (c3) on the
// `null_cases` table. The query computes SUM(c1) under many frame definitions:
//   * a..f    - the same 9 PRECEDING / 11 FOLLOWING frame under different sort
//               directions and NULLS FIRST/LAST placements;
//   * a1..a11 - ascending order with various bound combinations (CURRENT ROW,
//               bounded and UNBOUNDED PRECEDING/FOLLOWING);
//   * a21..a25 - a selection of those bound combinations repeated with DESC and/or
//               explicit NULLS FIRST/LAST orderings.
#[tokio::test]
async fn window_frame_groups() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as a,
SUM(c1) OVER (ORDER BY c3 DESC GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as b,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as c,
SUM(c1) OVER (ORDER BY c3 DESC NULLS last GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as d,
SUM(c1) OVER (ORDER BY c3 DESC NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as e,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as f,
SUM(c1) OVER (ORDER BY c3 GROUPS current row) as a1,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a2,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a3,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a4,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a5,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as a6,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as a7,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 3 FOLLOWING AND UNBOUNDED FOLLOWING) as a8,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN current row AND UNBOUNDED FOLLOWING) as a9,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN current row AND 3 FOLLOWING) as a10,
SUM(c1) OVER (ORDER BY c3 GROUPS BETWEEN 5 FOLLOWING AND 7 FOLLOWING) as a11,
SUM(c1) OVER (ORDER BY c3 DESC GROUPS current row) as a21,
SUM(c1) OVER (ORDER BY c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a22,
SUM(c1) OVER (ORDER BY c3 DESC NULLS last GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a23,
SUM(c1) OVER (ORDER BY c3 NULLS last GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a24,
SUM(c1) OVER (ORDER BY c3 DESC NULLS first GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a25
FROM null_cases
ORDER BY c3
LIMIT 10";
let actual = execute_to_batches(&ctx, sql).await;
// Only the first ten rows (by c3) are compared. Blank cells in the expected
// table are presumably NULL results, e.g. from an empty frame or all-NULL
// inputs (TODO confirm against the fixture data).
let expected = vec![
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
"| a | b | c | d | e | f | a1 | a2 | a3 | a4 | a5 | a6 | a7 | a8 | a9 | a10 | a11 | a21 | a22 | a23 | a24 | a25 |",
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
"| 412 | 307 | 412 | 307 | 307 | 412 | | | | 412 | | 4627 | 4627 | 4531 | 4627 | 115 | 85 | | | 4487 | 412 | 4627 |",
"| 488 | 339 | 488 | 339 | 339 | 488 | 72 | | | 488 | 72 | 4627 | 4627 | 4512 | 4627 | 140 | 153 | 72 | | 4473 | 488 | 4627 |",
"| 543 | 412 | 543 | 412 | 412 | 543 | 24 | | | 543 | 96 | 4627 | 4627 | 4487 | 4555 | 82 | 122 | 24 | | 4442 | 543 | 4555 |",
"| 553 | 488 | 553 | 488 | 488 | 553 | 19 | | | 553 | 115 | 4627 | 4555 | 4473 | 4531 | 89 | 114 | 19 | | 4402 | 553 | 4531 |",
"| 553 | 543 | 553 | 543 | 543 | 553 | 25 | | | 553 | 140 | 4627 | 4531 | 4442 | 4512 | 110 | 105 | 25 | | 4320 | 553 | 4512 |",
"| 591 | 553 | 591 | 553 | 553 | 591 | 14 | | | 591 | 154 | 4627 | 4512 | 4402 | 4487 | 167 | 181 | 14 | | 4320 | 591 | 4487 |",
"| 651 | 553 | 651 | 553 | 553 | 651 | 31 | 72 | 72 | 651 | 185 | 4627 | 4487 | 4320 | 4473 | 153 | 204 | 31 | 72 | 4288 | 651 | 4473 |",
"| 662 | 591 | 662 | 591 | 591 | 662 | 40 | 96 | 96 | 662 | 225 | 4627 | 4473 | 4320 | 4442 | 154 | 141 | 40 | 96 | 4215 | 662 | 4442 |",
"| 697 | 651 | 697 | 651 | 651 | 697 | 82 | 115 | 115 | 697 | 307 | 4627 | 4442 | 4288 | 4402 | 187 | 65 | 82 | 115 | 4139 | 697 | 4402 |",
"| 758 | 662 | 758 | 662 | 662 | 758 | | 140 | 140 | 758 | 307 | 4627 | 4402 | 4215 | 4320 | 181 | 48 | | 140 | 4084 | 758 | 4320 |",
"+-----+-----+-----+-----+-----+-----+----+-----+-----+-----+-----+------+------+------+------+-----+-----+-----+-----+------+-----+------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

// Same style of coverage as the single-column GROUPS test, but with two ORDER BY
// columns (c2, c3) inside each window definition. The query computes SUM(c1) on
// `null_cases` under:
//   * a..f    - a 9 PRECEDING / 11 FOLLOWING frame with different sort directions
//               and NULLS FIRST/LAST placements applied to c3;
//   * a1..a11 - ascending order with various bound combinations (CURRENT ROW,
//               bounded and UNBOUNDED PRECEDING/FOLLOWING).
#[tokio::test]
async fn window_frame_groups_multiple_order_columns() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_null_cases_csv(&ctx).await?;
let sql = "SELECT
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as a,
SUM(c1) OVER (ORDER BY c2, c3 DESC GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as b,
SUM(c1) OVER (ORDER BY c2, c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as c,
SUM(c1) OVER (ORDER BY c2, c3 DESC NULLS last GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as d,
SUM(c1) OVER (ORDER BY c2, c3 DESC NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as e,
SUM(c1) OVER (ORDER BY c2, c3 NULLS first GROUPS BETWEEN 9 PRECEDING AND 11 FOLLOWING) as f,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS current row) as a1,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 9 PRECEDING AND 5 PRECEDING) as a2,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 5 PRECEDING) as a3,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a4,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND current row) as a5,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as a6,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as a7,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 3 FOLLOWING AND UNBOUNDED FOLLOWING) as a8,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN current row AND UNBOUNDED FOLLOWING) as a9,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN current row AND 3 FOLLOWING) as a10,
SUM(c1) OVER (ORDER BY c2, c3 GROUPS BETWEEN 5 FOLLOWING AND 7 FOLLOWING) as a11
FROM null_cases
ORDER BY c3
LIMIT 10";
let actual = execute_to_batches(&ctx, sql).await;
// The output is ordered by c3 only (not by the window ordering c2, c3), and just
// the first ten rows are compared. Blank cells are presumably NULL results
// (TODO confirm against the fixture data).
let expected = vec![
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
"| a | b | c | d | e | f | a1 | a2 | a3 | a4 | a5 | a6 | a7 | a8 | a9 | a10 | a11 |",
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
"| 818 | 910 | 818 | 910 | 910 | 818 | | 249 | 249 | 818 | 432 | 4627 | 4234 | 4157 | 4195 | 98 | 82 |",
"| 537 | 979 | 537 | 979 | 979 | 537 | 72 | | | 537 | 210 | 4627 | 4569 | 4378 | 4489 | 169 | 55 |",
"| 811 | 838 | 811 | 838 | 838 | 811 | 24 | 221 | 3075 | 3665 | 3311 | 4627 | 1390 | 1276 | 1340 | 117 | 144 |",
"| 763 | 464 | 763 | 464 | 464 | 763 | 19 | 168 | 3572 | 4167 | 3684 | 4627 | 962 | 829 | 962 | 194 | 80 |",
"| 552 | 964 | 552 | 964 | 964 | 552 | 25 | | | 552 | 235 | 4627 | 4489 | 4320 | 4417 | 167 | 39 |",
"| 963 | 930 | 963 | 930 | 930 | 963 | 14 | 201 | 818 | 1580 | 1098 | 4627 | 3638 | 3455 | 3543 | 177 | 224 |",
"| 1113 | 814 | 1113 | 814 | 814 | 1113 | 31 | 415 | 2653 | 3351 | 2885 | 4627 | 1798 | 1694 | 1773 | 165 | 162 |",
"| 780 | 868 | 780 | 868 | 868 | 780 | 40 | 258 | 3143 | 3665 | 3351 | 4627 | 1340 | 1223 | 1316 | 117 | 102 |",
"| 740 | 466 | 740 | 466 | 466 | 740 | 82 | 164 | 3592 | 4168 | 3766 | 4627 | 962 | 768 | 943 | 244 | 122 |",
"| 772 | 832 | 772 | 832 | 832 | 772 | | 277 | 3189 | 3684 | 3351 | 4627 | 1316 | 1199 | 1276 | 119 | 64 |",
"+------+-----+------+-----+-----+------+----+-----+------+------+------+------+------+------+------+-----+-----+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn window_frame_groups_without_order_by() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
// execute the query
let df = ctx
.sql(
"SELECT
COUNT(c1) OVER (ORDER BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100;",
SUM(c4) OVER(PARTITION BY c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
FROM aggregate_test_100
ORDER BY c9;",
)
.await?;
let results = df.collect().await;
assert!(results
.as_ref()
.err()
.unwrap()
.to_string()
.contains("Window frame definitions involving GROUPS are not supported yet"));
let err = df.collect().await.unwrap_err();
assert_contains!(
err.to_string(),
"Execution error: GROUPS mode requires an ORDER BY clause".to_owned()
);
Ok(())
}

Expand Down
6 changes: 4 additions & 2 deletions datafusion/physical-expr/src/window/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ use datafusion_expr::WindowFrame;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};

use super::window_frame_state::WindowFrameContext;

/// A window expr that takes the form of an aggregate function
#[derive(Debug)]
pub struct AggregateWindowExpr {
Expand Down Expand Up @@ -114,13 +116,13 @@ impl WindowExpr for AggregateWindowExpr {
.map(|v| v.slice(partition_range.start, length))
.collect::<Vec<_>>();

let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
let mut last_range: (usize, usize) = (0, 0);

// We iterate on each row to perform a running calculation.
// First, cur_range is calculated, then it is compared with last_range.
for i in 0..length {
let cur_range = self.calculate_range(
&window_frame,
let cur_range = window_frame_ctx.calculate_range(
&slice_order_bys,
&sort_options,
length,
Expand Down
5 changes: 3 additions & 2 deletions datafusion/physical-expr/src/window/built_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Physical exec for built-in window function expressions.

use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
Expand Down Expand Up @@ -113,10 +114,10 @@ impl WindowExpr for BuiltInWindowExpr {
.iter()
.map(|v| v.slice(partition_range.start, length))
.collect::<Vec<_>>();
let mut window_frame_ctx = WindowFrameContext::new(&window_frame);
// We iterate on each row to calculate window frame range and window function result
for idx in 0..length {
let range = self.calculate_range(
&window_frame,
let range = window_frame_ctx.calculate_range(
&slice_order_bys,
&sort_options,
num_rows,
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub(crate) mod partition_evaluator;
pub(crate) mod rank;
pub(crate) mod row_number;
mod window_expr;
mod window_frame_state;

pub use aggregate::AggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
Expand Down
Loading

0 comments on commit 129654c

Please sign in to comment.