Skip to content

Commit 4d24c5b

Browse files
committed
resolve conflicts & add fetch to struct
1 parent 2315376 commit 4d24c5b

File tree

2 files changed

+155
-36
lines changed

2 files changed

+155
-36
lines changed

datafusion/core/tests/physical_optimizer/enforce_distribution.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,27 @@ use datafusion::datasource::file_format::file_compression_type::FileCompressionT
3131
use datafusion::datasource::listing::PartitionedFile;
3232
use datafusion::datasource::object_store::ObjectStoreUrl;
3333
use datafusion::datasource::physical_plan::{CsvSource, FileScanConfig, ParquetSource};
34+
use datafusion::execution::SessionStateBuilder;
35+
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
36+
use datafusion::prelude::SessionContext;
3437
use datafusion_common::error::Result;
3538
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3639
use datafusion_common::ScalarValue;
40+
use datafusion_execution::config::SessionConfig;
3741
use datafusion_expr::{JoinType, Operator};
3842
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
3943
use datafusion_physical_expr::PhysicalExpr;
4044
use datafusion_physical_expr::{
4145
expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr,
4246
};
4347
use datafusion_physical_expr_common::sort_expr::LexRequirement;
48+
use datafusion_physical_optimizer::coalesce_batches::CoalesceBatches;
4449
use datafusion_physical_optimizer::enforce_distribution::*;
4550
use datafusion_physical_optimizer::enforce_sorting::EnforceSorting;
51+
use datafusion_physical_optimizer::limit_pushdown::LimitPushdown;
4652
use datafusion_physical_optimizer::output_requirements::OutputRequirements;
53+
use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown;
54+
use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan;
4755
use datafusion_physical_optimizer::PhysicalOptimizerRule;
4856
use datafusion_physical_plan::aggregates::{
4957
AggregateExec, AggregateMode, PhysicalGroupBy,
@@ -62,6 +70,7 @@ use datafusion_physical_plan::union::UnionExec;
6270
use datafusion_physical_plan::ExecutionPlanProperties;
6371
use datafusion_physical_plan::PlanProperties;
6472
use datafusion_physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics};
73+
use futures::StreamExt;
6574

6675
/// Models operators like BoundedWindowExec that require an input
6776
/// ordering but is easy to construct
@@ -3154,3 +3163,77 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> {
31543163

31553164
Ok(())
31563165
}
3166+
3167+
/// Regression test: runs a `UNION ALL ... ORDER BY c13 LIMIT 5` query through a
/// custom physical optimizer rule list in which `EnforceDistribution` appears
/// twice, then checks that the first batch of partition 0 still has 5 rows.
#[tokio::test]
3168+
async fn apply_enforce_distribution_multiple_times() -> Result<()> {
3169+
// Session context built from a default `SessionConfig`
3170+
let config = SessionConfig::new();
3171+
let ctx = SessionContext::new_with_config(config);
3172+
3173+
// Register the CSV-backed external table `aggregate_test_100`
3174+
let sql = "CREATE EXTERNAL TABLE aggregate_test_100 (
3175+
c1 VARCHAR NOT NULL,
3176+
c2 TINYINT NOT NULL,
3177+
c3 SMALLINT NOT NULL,
3178+
c4 SMALLINT,
3179+
c5 INT,
3180+
c6 BIGINT NOT NULL,
3181+
c7 SMALLINT NOT NULL,
3182+
c8 INT NOT NULL,
3183+
c9 BIGINT UNSIGNED NOT NULL,
3184+
c10 VARCHAR NOT NULL,
3185+
c11 FLOAT NOT NULL,
3186+
c12 DOUBLE NOT NULL,
3187+
c13 VARCHAR NOT NULL
3188+
)
3189+
STORED AS CSV
3190+
LOCATION '../../testing/data/csv/aggregate_test_100.csv'
3191+
OPTIONS ('format.has_header' 'true')";
3192+
3193+
ctx.sql(sql).await?;
3194+
3195+
// Query under test: the table unioned with itself, sorted by c13, limited to 5 rows
let df = ctx.sql("SELECT * FROM(SELECT * FROM aggregate_test_100 UNION ALL SELECT * FROM aggregate_test_100) ORDER BY c13 LIMIT 5").await?;
3196+
let logical_plan = df.logical_plan().clone();
3197+
// Run the analyzer and the logical optimizer by hand; the observer callbacks are no-ops
let analyzed_logical_plan = ctx.state().analyzer().execute_and_check(
3198+
logical_plan,
3199+
ctx.state().config_options(),
3200+
|_, _| (),
3201+
)?;
3202+
let optimized_logical_plan = ctx.state().optimizer().optimize(
3203+
analyzed_logical_plan,
3204+
&ctx.state(),
3205+
|_, _| (),
3206+
)?;
3207+
3208+
// Explicit physical optimizer rule list, so that `EnforceDistribution` can be listed twice
let optimizers: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
3209+
Arc::new(OutputRequirements::new_add_mode()),
3210+
Arc::new(EnforceDistribution::new()),
3211+
Arc::new(EnforceSorting::new()),
3212+
Arc::new(ProjectionPushdown::new()),
3213+
Arc::new(CoalesceBatches::new()),
3214+
Arc::new(EnforceDistribution::new()), // second application of the enforce distribution rule (the case under test)
3215+
Arc::new(OutputRequirements::new_remove_mode()),
3216+
Arc::new(ProjectionPushdown::new()),
3217+
Arc::new(LimitPushdown::new()),
3218+
Arc::new(SanityCheckPlan::new()),
3219+
];
3220+
3221+
// Plan with a separate session state that reuses the context's config but installs the rule list above
let planner = DefaultPhysicalPlanner::default();
3222+
let session_state = SessionStateBuilder::new()
3223+
.with_config(ctx.copied_config())
3224+
.with_default_features()
3225+
.with_physical_optimizer_rules(optimizers)
3226+
.build();
3227+
let optimized_physical_plan = planner
3228+
.create_physical_plan(&optimized_logical_plan, &session_state)
3229+
.await?;
3230+
3231+
// Execute partition 0 only and inspect the first batch it yields
let mut results = optimized_physical_plan
3232+
.execute(0, ctx.task_ctx().clone())
3233+
.unwrap();
3234+
3235+
let batch = results.next().await.unwrap()?;
3236+
// Without the fix from https://github.com/apache/datafusion/pull/14207, the number of rows would be 10 instead of 5
3237+
assert_eq!(batch.num_rows(), 5);
3238+
Ok(())
3239+
}

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
//! according to the configuration), this rule increases partition counts in
2222
//! the physical plan.
2323
24-
use std::fmt::Debug;
24+
use std::fmt;
25+
use std::fmt::{Debug, Display, Formatter};
2526
use std::sync::Arc;
2627

2728
use crate::optimizer::PhysicalOptimizerRule;
@@ -862,7 +863,11 @@ fn add_roundrobin_on_top(
862863

863864
let new_plan = Arc::new(repartition) as _;
864865

865-
Ok(DistributionContext::new(new_plan, true, vec![input]))
866+
Ok(DistributionContext::new(
867+
new_plan,
868+
DistributionData::new(true),
869+
vec![input],
870+
))
866871
} else {
867872
// Partition is not helpful, we already have desired number of partitions.
868873
Ok(input)
@@ -920,7 +925,11 @@ fn add_hash_on_top(
920925
.with_preserve_order();
921926
let plan = Arc::new(repartition) as _;
922927

923-
return Ok(DistributionContext::new(plan, true, vec![input]));
928+
return Ok(DistributionContext::new(
929+
plan,
930+
DistributionData::new(true),
931+
vec![input],
932+
));
924933
}
925934

926935
Ok(input)
@@ -968,7 +977,7 @@ fn add_spm_on_top(
968977
Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _
969978
};
970979

971-
DistributionContext::new(new_plan, true, vec![input])
980+
DistributionContext::new(new_plan, DistributionData::new(true), vec![input])
972981
} else {
973982
input
974983
}
@@ -993,7 +1002,7 @@ fn add_spm_on_top(
9931002
/// ```
9941003
fn remove_dist_changing_operators(
9951004
mut distribution_context: DistributionContext,
996-
) -> Result<(DistributionContext, Option<usize>)> {
1005+
) -> Result<DistributionContext> {
9971006
let mut fetch = None;
9981007
while is_repartition(&distribution_context.plan)
9991008
|| is_coalesce_partitions(&distribution_context.plan)
@@ -1007,10 +1016,11 @@ fn remove_dist_changing_operators(
10071016
// All of above operators have a single child. First child is only child.
10081017
// Remove any distribution changing operators at the beginning:
10091018
distribution_context = distribution_context.children.swap_remove(0);
1019+
distribution_context.data.fetch = fetch;
10101020
// Note that they will be re-inserted later on if necessary or helpful.
10111021
}
10121022

1013-
Ok((distribution_context, fetch))
1023+
Ok(distribution_context)
10141024
}
10151025

10161026
/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable.
@@ -1033,14 +1043,14 @@ fn remove_dist_changing_operators(
10331043
/// ```
10341044
fn replace_order_preserving_variants(
10351045
mut context: DistributionContext,
1036-
) -> Result<(DistributionContext, Option<usize>)> {
1046+
) -> Result<DistributionContext> {
10371047
let mut children = vec![];
10381048
let mut fetch = None;
10391049
for child in context.children.into_iter() {
1040-
if child.data {
1041-
let (child, inner_fetch) = replace_order_preserving_variants(child)?;
1050+
if child.data.has_dist_changing {
1051+
let child = replace_order_preserving_variants(child)?;
1052+
fetch = child.data.fetch;
10421053
children.push(child);
1043-
fetch = inner_fetch;
10441054
} else {
10451055
children.push(child);
10461056
}
@@ -1052,7 +1062,8 @@ fn replace_order_preserving_variants(
10521062
let fetch = context.plan.fetch();
10531063
let child_plan = Arc::clone(&context.children[0].plan);
10541064
context.plan = Arc::new(CoalescePartitionsExec::new(child_plan));
1055-
return Ok((context, fetch));
1065+
context.data.fetch = fetch;
1066+
return Ok(context);
10561067
} else if let Some(repartition) =
10571068
context.plan.as_any().downcast_ref::<RepartitionExec>()
10581069
{
@@ -1061,11 +1072,12 @@ fn replace_order_preserving_variants(
10611072
Arc::clone(&context.children[0].plan),
10621073
repartition.partitioning().clone(),
10631074
)?);
1064-
return Ok((context, None));
1075+
return Ok(context);
10651076
}
10661077
}
10671078

1068-
Ok((context.update_plan_from_children()?, fetch))
1079+
context.data.fetch = fetch;
1080+
context.update_plan_from_children()
10691081
}
10701082

10711083
/// A struct to keep track of repartition requirements for each child node.
@@ -1202,14 +1214,11 @@ pub fn ensure_distribution(
12021214
unbounded_and_pipeline_friendly || config.optimizer.prefer_existing_sort;
12031215

12041216
// Remove unnecessary repartition from the physical plan if any
1205-
let (
1206-
DistributionContext {
1207-
mut plan,
1208-
data,
1209-
children,
1210-
},
1211-
mut fetch,
1212-
) = remove_dist_changing_operators(dist_context)?;
1217+
let DistributionContext {
1218+
mut plan,
1219+
mut data,
1220+
children,
1221+
} = remove_dist_changing_operators(dist_context)?;
12131222

12141223
if let Some(exec) = plan.as_any().downcast_ref::<WindowAggExec>() {
12151224
if let Some(updated_window) = get_best_fitting_window(
@@ -1274,7 +1283,7 @@ pub fn ensure_distribution(
12741283
// Satisfy the distribution requirement if it is unmet.
12751284
match &requirement {
12761285
Distribution::SinglePartition => {
1277-
child = add_spm_on_top(child, &mut fetch);
1286+
child = add_spm_on_top(child, &mut data.fetch);
12781287
}
12791288
Distribution::HashPartitioned(exprs) => {
12801289
if add_roundrobin {
@@ -1307,14 +1316,13 @@ pub fn ensure_distribution(
13071316
.equivalence_properties()
13081317
.ordering_satisfy_requirement(&required_input_ordering);
13091318
if (!ordering_satisfied || !order_preserving_variants_desirable)
1310-
&& child.data
1319+
&& child.data.has_dist_changing
13111320
{
1312-
let (replaced_child, fetch) =
1313-
replace_order_preserving_variants(child)?;
1314-
child = replaced_child;
1321+
child = replace_order_preserving_variants(child)?;
13151322
// If ordering requirements were satisfied before repartitioning,
13161323
// make sure ordering requirements are still satisfied after.
13171324
if ordering_satisfied {
1325+
let fetch = child.data.fetch;
13181326
// Make sure to satisfy ordering requirement:
13191327
child = add_sort_above_with_check(
13201328
child,
@@ -1324,19 +1332,19 @@ pub fn ensure_distribution(
13241332
}
13251333
}
13261334
// Stop tracking distribution changing operators
1327-
child.data = false;
1335+
child.data.has_dist_changing = false;
13281336
} else {
13291337
// no ordering requirement
13301338
match requirement {
13311339
// Operator requires specific distribution.
13321340
Distribution::SinglePartition | Distribution::HashPartitioned(_) => {
13331341
// Since there is no ordering requirement, preserving ordering is pointless
1334-
child = replace_order_preserving_variants(child)?.0;
1342+
child = replace_order_preserving_variants(child)?;
13351343
}
13361344
Distribution::UnspecifiedDistribution => {
13371345
// Since ordering is lost, trying to preserve ordering is pointless
13381346
if !maintains || plan.as_any().is::<OutputRequirementExec>() {
1339-
child = replace_order_preserving_variants(child)?.0;
1347+
child = replace_order_preserving_variants(child)?;
13401348
}
13411349
}
13421350
}
@@ -1386,15 +1394,15 @@ pub fn ensure_distribution(
13861394
// If `fetch` was not consumed, it means that there was `SortPreservingMergeExec` with fetch before
13871395
// It was removed by `remove_dist_changing_operators`
13881396
// and we need to add it back.
1389-
if fetch.is_some() {
1397+
if data.fetch.is_some() {
13901398
plan = Arc::new(
13911399
SortPreservingMergeExec::new(
13921400
plan.output_ordering()
13931401
.unwrap_or(&LexOrdering::default())
13941402
.clone(),
13951403
plan,
13961404
)
1397-
.with_fetch(fetch.take()),
1405+
.with_fetch(data.fetch.take()),
13981406
)
13991407
}
14001408

@@ -1403,16 +1411,44 @@ pub fn ensure_distribution(
14031411
)))
14041412
}
14051413

1414+
/// Per-node payload carried by [`DistributionContext`]: tracks distribution
/// changing operators and fetch limits
1415+
#[derive(Debug, Clone, Default)]
1416+
pub struct DistributionData {
1417+
/// Whether this node contains distribution changing operators
1418+
pub has_dist_changing: bool,
1419+
/// Limit which must be applied to any sort preserving merge that is created
1420+
pub fetch: Option<usize>,
1421+
}
1422+
1423+
impl DistributionData {
1424+
/// Creates a `DistributionData` with the given `has_dist_changing` flag
/// and no fetch limit (`fetch` is `None`).
fn new(has_dist_changing: bool) -> Self {
1425+
Self {
1426+
has_dist_changing,
1427+
fetch: None,
1428+
}
1429+
}
1430+
}
1431+
1432+
/// Renders both fields on one line, e.g. `(has_dist_changing: true, fetch: Some(5))`.
impl Display for DistributionData {
1433+
/// Writes `has_dist_changing` with `Display` and `fetch` with `Debug` formatting.
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
1434+
write!(
1435+
f,
1436+
"(has_dist_changing: {}, fetch: {:?})",
1437+
self.has_dist_changing, self.fetch
1438+
)
1439+
}
1440+
}
1441+
14061442
/// Keeps track of distribution changing operators (like `RepartitionExec`,
14071443
/// `SortPreservingMergeExec`, `CoalescePartitionsExec`) and their ancestors.
14081444
/// Using this information, we can optimize distribution of the plan if/when
14091445
/// necessary.
1410-
pub type DistributionContext = PlanContext<bool>;
1446+
pub type DistributionContext = PlanContext<DistributionData>;
14111447

14121448
fn update_children(mut dist_context: DistributionContext) -> Result<DistributionContext> {
14131449
for child_context in dist_context.children.iter_mut() {
14141450
let child_plan_any = child_context.plan.as_any();
1415-
child_context.data =
1451+
child_context.data.has_dist_changing =
14161452
if let Some(repartition) = child_plan_any.downcast_ref::<RepartitionExec>() {
14171453
!matches!(
14181454
repartition.partitioning(),
@@ -1422,14 +1458,14 @@ fn update_children(mut dist_context: DistributionContext) -> Result<Distribution
14221458
child_plan_any.is::<SortPreservingMergeExec>()
14231459
|| child_plan_any.is::<CoalescePartitionsExec>()
14241460
|| child_context.plan.children().is_empty()
1425-
|| child_context.children[0].data
1461+
|| child_context.children[0].data.has_dist_changing
14261462
|| child_context
14271463
.plan
14281464
.required_input_distribution()
14291465
.iter()
14301466
.zip(child_context.children.iter())
14311467
.any(|(required_dist, child_context)| {
1432-
child_context.data
1468+
child_context.data.has_dist_changing
14331469
&& matches!(
14341470
required_dist,
14351471
Distribution::UnspecifiedDistribution
@@ -1438,7 +1474,7 @@ fn update_children(mut dist_context: DistributionContext) -> Result<Distribution
14381474
}
14391475
}
14401476

1441-
dist_context.data = false;
1477+
dist_context.data.has_dist_changing = false;
14421478
Ok(dist_context)
14431479
}
14441480

0 commit comments

Comments
 (0)