Skip to content

Commit

Permalink
Determine causal window frames to produce early results. (#8842)
Browse files Browse the repository at this point in the history
* add handling for primary key window frame

* Add check for window end

* Add uniqueness check

* Minor changes

* Update signature of WindowFrame::new

* Make is_causal window state

* Minor changes

* Address reviews

* Minor changes

* Add new test

* Minor changes

* Minor changes

* Remove string handling

* Review Part 2

* Improve comments

---------

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
  • Loading branch information
mustafasrepo and ozankabak authored Jan 15, 2024
1 parent aff7094 commit 259e12c
Show file tree
Hide file tree
Showing 23 changed files with 382 additions and 149 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
1 change: 0 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,6 @@ impl ScalarValue {

/// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
assert!(datatype.is_primitive());
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(Some(false)),
DataType::Int8 => ScalarValue::Int8(Some(0)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ mod tests {
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub fn bounded_window_exec(
&[col(col_name, &schema).unwrap()],
&[],
&sort_exprs,
Arc::new(WindowFrame::new(true)),
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
)
.unwrap()],
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ async fn test_count_wildcard_on_window() -> Result<()> {
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame {
units: WindowFrameUnits::Range,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
},
WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
))])?
.explain(false, false)?
.collect()
Expand Down
95 changes: 85 additions & 10 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,89 @@ async fn window_bounded_window_random_comparison() -> Result<()> {
Ok(())
}

// This tests whether we can generate bounded window results for each input
// batch immediately for causal window frames.
#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn bounded_window_causal_non_causal() -> Result<()> {
let session_config = SessionConfig::new();
let ctx = SessionContext::new_with_config(session_config);
let mut batches = make_staggered_batches::<true>(1000, 10, 23_u64);
// Remove empty batches:
batches.retain(|batch| batch.num_rows() > 0);
let schema = batches[0].schema();
let memory_exec = Arc::new(MemoryExec::try_new(
&[batches.clone()],
schema.clone(),
None,
)?);
let window_fn = WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
let fn_name = "COUNT".to_string();
let args = vec![col("x", &schema)?];
let partitionby_exprs = vec![];
let orderby_exprs = vec![];
// Window frame starts with "UNBOUNDED PRECEDING":
let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));

// Simulate cases of the following form:
// COUNT(x) OVER (
// ROWS BETWEEN UNBOUNDED PRECEDING AND <end_bound> PRECEDING/FOLLOWING
// )
for is_preceding in [false, true] {
for end_bound in [0, 1, 2, 3] {
let end_bound = if is_preceding {
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(end_bound)))
} else {
WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound)))
};
let window_frame = WindowFrame::new_bounds(
WindowFrameUnits::Rows,
start_bound.clone(),
end_bound,
);
let causal = window_frame.is_causal();

let window_expr = create_window_expr(
&window_fn,
fn_name.clone(),
&args,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame),
schema.as_ref(),
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![window_expr],
memory_exec.clone(),
vec![],
InputOrderMode::Linear,
)?);
let task_ctx = ctx.task_ctx();
let mut collected_results = collect(running_window_exec, task_ctx).await?;
collected_results.retain(|batch| batch.num_rows() > 0);
let input_batch_sizes = batches
.iter()
.map(|batch| batch.num_rows())
.collect::<Vec<_>>();
let result_batch_sizes = collected_results
.iter()
.map(|batch| batch.num_rows())
.collect::<Vec<_>>();
if causal {
// For causal window frames, we can generate results immediately
// for each input batch. Hence, batch sizes should match.
assert_eq!(input_batch_sizes, result_batch_sizes);
} else {
// For non-causal window frames, we cannot generate results
// immediately for each input batch. Hence, batch sizes shouldn't
// match.
assert_ne!(input_batch_sizes, result_batch_sizes);
}
}
}

Ok(())
}

fn get_random_function(
schema: &SchemaRef,
rng: &mut StdRng,
Expand Down Expand Up @@ -343,11 +426,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down Expand Up @@ -375,11 +454,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
end_bound.val as u64,
)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ where
/// vec![col("speed")], // smooth_it(speed)
/// vec![col("car")], // PARTITION BY car
/// vec![col("time").sort(true, true)], // ORDER BY time ASC
/// WindowFrame::new(false),
/// WindowFrame::new(None),
/// );
/// ```
pub trait WindowUDFImpl: Debug + Send + Sync {
Expand Down
20 changes: 10 additions & 10 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,28 +1252,28 @@ mod tests {
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
Expand All @@ -1295,28 +1295,28 @@ mod tests {
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
Expand Down Expand Up @@ -1350,7 +1350,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
Expand All @@ -1361,7 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
];
let expected = vec![
Expand Down
Loading

0 comments on commit 259e12c

Please sign in to comment.