Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine causal window frames to produce early results. #8842

Merged
merged 15 commits into from
Jan 15, 2024
2 changes: 1 addition & 1 deletion datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
1 change: 0 additions & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,6 @@ impl ScalarValue {

/// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
assert!(datatype.is_primitive());
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(Some(false)),
DataType::Int8 => ScalarValue::Int8(Some(0)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ mod tests {
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub fn bounded_window_exec(
&[col(col_name, &schema).unwrap()],
&[],
&sort_exprs,
Arc::new(WindowFrame::new(true)),
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
)
.unwrap()],
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ async fn test_count_wildcard_on_window() -> Result<()> {
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame {
units: WindowFrameUnits::Range,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
},
WindowFrame::new_bounds(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
))])?
.explain(false, false)?
.collect()
Expand Down
95 changes: 85 additions & 10 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,89 @@ async fn window_bounded_window_random_comparison() -> Result<()> {
Ok(())
}

/// Tests whether bounded window results are generated immediately for each
/// input batch when the window frame is causal, and that they are *not*
/// generated batch-by-batch when the frame is non-causal.
///
/// The check compares the row counts of the (non-empty) input batches with
/// the row counts of the (non-empty) batches produced by
/// `BoundedWindowAggExec`: identical sequences are expected for causal
/// frames, differing sequences for non-causal ones.
#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn bounded_window_causal_non_causal() -> Result<()> {
    let session_config = SessionConfig::new();
    let ctx = SessionContext::new_with_config(session_config);
    let mut batches = make_staggered_batches::<true>(1000, 10, 23_u64);
    // Remove empty batches:
    batches.retain(|batch| batch.num_rows() > 0);
    let schema = batches[0].schema();
    // The input is identical for every frame tested below, so its batch sizes
    // are computed once here. Doing this before building the `MemoryExec`
    // also lets `batches` be moved into it instead of cloned.
    let input_batch_sizes = batches
        .iter()
        .map(|batch| batch.num_rows())
        .collect::<Vec<_>>();
    let memory_exec = Arc::new(MemoryExec::try_new(&[batches], schema.clone(), None)?);
    let window_fn = WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
    let fn_name = "COUNT".to_string();
    let args = vec![col("x", &schema)?];
    let partitionby_exprs = vec![];
    let orderby_exprs = vec![];
    // Window frame starts with "UNBOUNDED PRECEDING":
    let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));

    // Simulate cases of the following form:
    // COUNT(x) OVER (
    //     ROWS BETWEEN UNBOUNDED PRECEDING AND <offset> PRECEDING/FOLLOWING
    // )
    for is_preceding in [false, true] {
        for offset in [0, 1, 2, 3] {
            let end_bound = if is_preceding {
                WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset)))
            } else {
                WindowFrameBound::Following(ScalarValue::UInt64(Some(offset)))
            };
            let window_frame = WindowFrame::new_bounds(
                WindowFrameUnits::Rows,
                start_bound.clone(),
                end_bound,
            );
            // Query causality before the frame is moved into the `Arc` below.
            let causal = window_frame.is_causal();

            let window_expr = create_window_expr(
                &window_fn,
                fn_name.clone(),
                &args,
                &partitionby_exprs,
                &orderby_exprs,
                Arc::new(window_frame),
                schema.as_ref(),
            )?;
            let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
                vec![window_expr],
                memory_exec.clone(),
                vec![],
                InputOrderMode::Linear,
            )?);
            let task_ctx = ctx.task_ctx();
            let mut collected_results = collect(running_window_exec, task_ctx).await?;
            // Empty output batches carry no information about emission timing.
            collected_results.retain(|batch| batch.num_rows() > 0);
            let result_batch_sizes = collected_results
                .iter()
                .map(|batch| batch.num_rows())
                .collect::<Vec<_>>();
            if causal {
                // For causal window frames, we can generate results immediately
                // for each input batch. Hence, batch sizes should match.
                assert_eq!(
                    input_batch_sizes, result_batch_sizes,
                    "causal frame should emit one result batch per input batch \
                     (is_preceding={}, offset={})",
                    is_preceding, offset
                );
            } else {
                // For non-causal window frames, we cannot generate results
                // immediately for each input batch. Hence, batch sizes shouldn't
                // match.
                assert_ne!(
                    input_batch_sizes, result_batch_sizes,
                    "non-causal frame should not emit results batch-by-batch \
                     (is_preceding={}, offset={})",
                    is_preceding, offset
                );
            }
        }
    }

    Ok(())
}

fn get_random_function(
schema: &SchemaRef,
rng: &mut StdRng,
Expand Down Expand Up @@ -343,11 +426,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
// With 10% probability, use an unbounded preceding start bound in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down Expand Up @@ -375,11 +454,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
end_bound.val as u64,
)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound);
// With 10% probability, use an unbounded preceding start bound in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ where
/// vec![col("speed")], // smooth_it(speed)
/// vec![col("car")], // PARTITION BY car
/// vec![col("time").sort(true, true)], // ORDER BY time ASC
/// WindowFrame::new(false),
/// WindowFrame::new(None),
/// );
/// ```
pub trait WindowUDFImpl {
Expand Down
20 changes: 10 additions & 10 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,28 +1252,28 @@ mod tests {
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
Expand All @@ -1295,28 +1295,28 @@ mod tests {
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
Expand Down Expand Up @@ -1350,7 +1350,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
Expand All @@ -1361,7 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
];
let expected = vec![
Expand Down
Loading
Loading