Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 250 additions & 1 deletion src/distributed_planner/plan_annotator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ fn _annotate_plan(
task_count,
plan,
};
if annotated_plan.required_network_boundary.is_none() {
if !(root || annotated_plan.required_network_boundary.is_some()) {
return Ok(annotated_plan);
};

Expand Down Expand Up @@ -299,3 +299,252 @@ fn required_network_boundary_below(parent: &dyn ExecutionPlan) -> Option<Require

None
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::in_memory_channel_resolver::InMemoryChannelResolver;
use crate::test_utils::parquet::register_parquet_tables;
use crate::{DistributedExt, assert_snapshot};
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};
use itertools::Itertools;

/* shema for the "weather" table

MinTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
MaxTemp [type=DOUBLE] [repetitiontype=OPTIONAL]
Rainfall [type=DOUBLE] [repetitiontype=OPTIONAL]
Evaporation [type=DOUBLE] [repetitiontype=OPTIONAL]
Sunshine [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindGustDir [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindGustSpeed [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindDir9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindDir3pm [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindSpeed9am [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
WindSpeed3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
Humidity9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
Humidity3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
Pressure9am [type=DOUBLE] [repetitiontype=OPTIONAL]
Pressure3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
Cloud9am [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
Cloud3pm [type=INT64] [convertedtype=INT_64] [repetitiontype=OPTIONAL]
Temp9am [type=DOUBLE] [repetitiontype=OPTIONAL]
Temp3pm [type=DOUBLE] [repetitiontype=OPTIONAL]
RainToday [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
RISK_MM [type=DOUBLE] [repetitiontype=OPTIONAL]
RainTomorrow [type=BYTE_ARRAY] [convertedtype=UTF8] [repetitiontype=OPTIONAL]
*/

#[tokio::test]
async fn test_select_all() {
let query = r#"
SELECT * FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @"DataSourceExec: task_count=Desired(3)")
}

#[tokio::test]
async fn test_aggregation() {
let query = r#"
SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*)
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ProjectionExec: task_count=Maximum(1)
SortPreservingMergeExec: task_count=Maximum(1), required_network_boundary=Coalesce
SortExec: task_count=Desired(2)
ProjectionExec: task_count=Desired(2)
AggregateExec: task_count=Desired(2)
CoalesceBatchesExec: task_count=Desired(2), required_network_boundary=Shuffle
RepartitionExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

#[tokio::test]
async fn test_left_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp" FROM weather a LEFT JOIN weather b ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
CoalesceBatchesExec: task_count=Maximum(1)
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}

#[tokio::test]
async fn test_left_join_distributed() {
let query = r#"
WITH a AS (
SELECT
AVG("MinTemp") as "MinTemp",
"RainTomorrow"
FROM weather
WHERE "RainToday" = 'yes'
GROUP BY "RainTomorrow"
), b AS (
SELECT
AVG("MaxTemp") as "MaxTemp",
"RainTomorrow"
FROM weather
WHERE "RainToday" = 'no'
GROUP BY "RainTomorrow"
)
SELECT
a."MinTemp",
b."MaxTemp"
FROM a
LEFT JOIN b
ON a."RainTomorrow" = b."RainTomorrow"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
CoalesceBatchesExec: task_count=Maximum(1)
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1), required_network_boundary=Coalesce
ProjectionExec: task_count=Desired(2)
AggregateExec: task_count=Desired(2)
CoalesceBatchesExec: task_count=Desired(2), required_network_boundary=Shuffle
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
CoalesceBatchesExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
ProjectionExec: task_count=Maximum(1)
AggregateExec: task_count=Maximum(1)
CoalesceBatchesExec: task_count=Maximum(1), required_network_boundary=Shuffle
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
CoalesceBatchesExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

#[tokio::test]
async fn test_inner_join() {
let query = r#"
SELECT a."MinTemp", b."MaxTemp" FROM weather a INNER JOIN weather b ON a."RainToday" = b."RainToday"
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
CoalesceBatchesExec: task_count=Maximum(1)
HashJoinExec: task_count=Maximum(1)
CoalescePartitionsExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
DataSourceExec: task_count=Maximum(1)
")
}

#[tokio::test]
async fn test_distinct() {
let query = r#"
SELECT DISTINCT "RainToday" FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
AggregateExec: task_count=Desired(2)
CoalesceBatchesExec: task_count=Desired(2), required_network_boundary=Shuffle
RepartitionExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
AggregateExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

#[tokio::test]
async fn test_union_all() {
let query = r#"
SELECT "MinTemp" FROM weather WHERE "RainToday" = 'yes'
UNION ALL
SELECT "MaxTemp" FROM weather WHERE "RainToday" = 'no'
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
UnionExec: task_count=Desired(3)
CoalesceBatchesExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
ProjectionExec: task_count=Desired(3)
CoalesceBatchesExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

#[tokio::test]
async fn test_subquery() {
let query = r#"
SELECT * FROM (
SELECT "MinTemp", "MaxTemp" FROM weather WHERE "RainToday" = 'yes'
) AS subquery WHERE "MinTemp" > 5
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
CoalesceBatchesExec: task_count=Desired(3)
FilterExec: task_count=Desired(3)
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

#[tokio::test]
async fn test_window_function() {
let query = r#"
SELECT "MinTemp", ROW_NUMBER() OVER (PARTITION BY "RainToday" ORDER BY "MinTemp") as rn
FROM weather
"#;
let annotated = sql_to_annotated(query).await;
assert_snapshot!(annotated, @r"
ProjectionExec: task_count=Desired(3)
BoundedWindowAggExec: task_count=Desired(3)
SortExec: task_count=Desired(3)
CoalesceBatchesExec: task_count=Desired(3), required_network_boundary=Shuffle
RepartitionExec: task_count=Desired(3)
DataSourceExec: task_count=Desired(3)
")
}

async fn sql_to_annotated(query: &str) -> String {
let config = SessionConfig::new()
.with_target_partitions(4)
.with_information_schema(true);

let state = SessionStateBuilder::new()
.with_default_features()
.with_config(config)
.with_distributed_channel_resolver(InMemoryChannelResolver::new(4))
.build();

let ctx = SessionContext::new_with_state(state);
let mut queries = query.split(";").collect_vec();
let last_query = queries.pop().unwrap();

for query in queries {
ctx.sql(query).await.unwrap();
}

register_parquet_tables(&ctx).await.unwrap();

let df = ctx.sql(last_query).await.unwrap();

let annotated = annotate_plan(
df.create_physical_plan().await.unwrap(),
ctx.state_ref().read().config_options().as_ref(),
)
.expect("failed to annotate plan");
format!("{annotated:?}")
}
}