
Commit 09bbb6b

- refactor EnforceDistribution using transform_down_with_payload()
1 parent 5c61470 commit 09bbb6b
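
The change drops the PlanWithKeyRequirements wrapper tree and instead threads the parent's key requirement through the top-down walk as a payload: the callback receives the node plus the requirement its parent pushed down, may rewrite the node in place, and hands back one requirement per child. Below is a minimal, self-contained sketch of that payload-carrying traversal pattern; the toy Node type, Recursion enum, and free function are illustrative assumptions for this note, not DataFusion's actual API.

#[derive(Debug)]
struct Node {
    name: String,
    children: Vec<Node>,
}

/// Whether to keep descending into the current node's children.
#[allow(dead_code)] // `Stop` is not exercised in the toy walk below
enum Recursion {
    Continue,
    Stop,
}

/// Walk the tree top-down. At each node the callback may rewrite the node in
/// place and must return one payload per child; each child then sees exactly
/// the payload its parent produced for it.
fn transform_down_with_payload<P, F>(
    node: &mut Node,
    payload: P,
    f: &mut F,
) -> Result<(), String>
where
    F: FnMut(&mut Node, P) -> Result<(Recursion, Vec<P>), String>,
{
    let (recursion, child_payloads) = f(node, payload)?;
    if let Recursion::Stop = recursion {
        return Ok(());
    }
    assert_eq!(node.children.len(), child_payloads.len());
    for (child, child_payload) in node.children.iter_mut().zip(child_payloads) {
        transform_down_with_payload(child, child_payload, f)?;
    }
    Ok(())
}

fn main() -> Result<(), String> {
    let mut plan = Node {
        name: "hash_join".into(),
        children: vec![
            Node { name: "scan_a".into(), children: vec![] },
            Node { name: "scan_b".into(), children: vec![] },
        ],
    };
    // Analogous to adjust_input_keys_ordering: receive the requirement the
    // parent pushed down, optionally rewrite the node, and emit one
    // requirement per child.
    transform_down_with_payload(&mut plan, None::<String>, &mut |node, required| {
        if let Some(keys) = required {
            node.name = format!("{} [required: {keys}]", node.name);
        }
        let child_reqs = vec![Some("a = b".to_string()); node.children.len()];
        Ok((Recursion::Continue, child_reqs))
    })?;
    println!("{plan:#?}");
    Ok(())
}

Because the requirement travels as a payload of the walk itself, no wrapper node type or TreeNode implementation is needed to carry it, which is what allows the removal at the bottom of this diff.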

File tree

1 file changed (+92, -163 lines)

datafusion/core/src/physical_optimizer/enforce_distribution.rs

Lines changed: 92 additions & 163 deletions
@@ -191,17 +191,15 @@ impl EnforceDistribution {
 impl PhysicalOptimizerRule for EnforceDistribution {
     fn optimize(
         &self,
-        plan: Arc<dyn ExecutionPlan>,
+        mut plan: Arc<dyn ExecutionPlan>,
         config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering;
 
         let adjusted = if top_down_join_key_reordering {
             // Run a top-down process to adjust input key ordering recursively
-            let plan_requirements = PlanWithKeyRequirements::new(plan);
-            let adjusted =
-                plan_requirements.transform_down_old(&adjust_input_keys_ordering)?;
-            adjusted.plan
+            plan.transform_down_with_payload(&mut adjust_input_keys_ordering, None)?;
+            plan
         } else {
             // Run a bottom-up process
             plan.transform_up_old(&|plan| {
@@ -270,11 +268,12 @@ impl PhysicalOptimizerRule for EnforceDistribution {
 /// 5) For other types of operators, by default, pushdown the parent requirements to children.
 ///
 fn adjust_input_keys_ordering(
-    requirements: PlanWithKeyRequirements,
-) -> Result<Transformed<PlanWithKeyRequirements>> {
-    let parent_required = requirements.required_key_ordering.clone();
-    let plan_any = requirements.plan.as_any();
-    let transformed = if let Some(HashJoinExec {
+    plan: &mut Arc<dyn ExecutionPlan>,
+    required_key_ordering: Option<Vec<Arc<dyn PhysicalExpr>>>,
+) -> Result<(TreeNodeRecursion, Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>)> {
+    let parent_required = required_key_ordering.unwrap_or_default().clone();
+    let plan_any = plan.as_any();
+    if let Some(HashJoinExec {
         left,
         right,
         on,
@@ -299,13 +298,15 @@ fn adjust_input_keys_ordering(
                         *null_equals_null,
                     )?) as Arc<dyn ExecutionPlan>)
                 };
-                Some(reorder_partitioned_join_keys(
-                    requirements.plan.clone(),
+                let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+                    plan.clone(),
                     &parent_required,
                     on,
                     vec![],
                     &join_constructor,
-                )?)
+                )?;
+                *plan = new_plan;
+                Ok((TreeNodeRecursion::Continue, request_key_ordering))
             }
             PartitionMode::CollectLeft => {
                 let new_right_request = match join_type {
@@ -323,30 +324,28 @@ fn adjust_input_keys_ordering(
                 };
 
                 // Push down requirements to the right side
-                Some(PlanWithKeyRequirements {
-                    plan: requirements.plan.clone(),
-                    required_key_ordering: vec![],
-                    request_key_ordering: vec![None, new_right_request],
-                })
+                Ok((TreeNodeRecursion::Continue, vec![None, new_right_request]))
             }
             PartitionMode::Auto => {
                 // Can not satisfy, clear the current requirements and generate new empty requirements
-                Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+                Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                ))
             }
         }
     } else if let Some(CrossJoinExec { left, .. }) =
         plan_any.downcast_ref::<CrossJoinExec>()
     {
         let left_columns_len = left.schema().fields().len();
         // Push down requirements to the right side
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![
                 None,
                 shift_right_required(&parent_required, left_columns_len),
             ],
-        })
+        ))
     } else if let Some(SortMergeJoinExec {
         left,
         right,
@@ -368,26 +367,38 @@ fn adjust_input_keys_ordering(
                 *null_equals_null,
             )?) as Arc<dyn ExecutionPlan>)
         };
-        Some(reorder_partitioned_join_keys(
-            requirements.plan.clone(),
+        let (new_plan, request_key_ordering) = reorder_partitioned_join_keys(
+            plan.clone(),
             &parent_required,
             on,
             sort_options.clone(),
             &join_constructor,
-        )?)
+        )?;
+        *plan = new_plan;
+        Ok((TreeNodeRecursion::Continue, request_key_ordering))
     } else if let Some(aggregate_exec) = plan_any.downcast_ref::<AggregateExec>() {
         if !parent_required.is_empty() {
             match aggregate_exec.mode() {
-                AggregateMode::FinalPartitioned => Some(reorder_aggregate_keys(
-                    requirements.plan.clone(),
-                    &parent_required,
-                    aggregate_exec,
-                )?),
-                _ => Some(PlanWithKeyRequirements::new(requirements.plan.clone())),
+                AggregateMode::FinalPartitioned => {
+                    let (new_plan, request_key_ordering) = reorder_aggregate_keys(
+                        plan.clone(),
+                        &parent_required,
+                        aggregate_exec,
+                    )?;
+                    *plan = new_plan;
+                    Ok((TreeNodeRecursion::Continue, request_key_ordering))
+                }
+                _ => Ok((
+                    TreeNodeRecursion::Continue,
+                    vec![None; plan.children().len()],
+                )),
             }
         } else {
             // Keep everything unchanged
-            None
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
         }
     } else if let Some(proj) = plan_any.downcast_ref::<ProjectionExec>() {
         let expr = proj.expr();
@@ -396,34 +407,33 @@ fn adjust_input_keys_ordering(
         // Construct a mapping from new name to the the orginal Column
         let new_required = map_columns_before_projection(&parent_required, expr);
         if new_required.len() == parent_required.len() {
-            Some(PlanWithKeyRequirements {
-                plan: requirements.plan.clone(),
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(new_required.clone())],
-            })
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![Some(new_required.clone())],
+            ))
         } else {
             // Can not satisfy, clear the current requirements and generate new empty requirements
-            Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+            Ok((
+                TreeNodeRecursion::Continue,
+                vec![None; plan.children().len()],
+            ))
         }
     } else if plan_any.downcast_ref::<RepartitionExec>().is_some()
         || plan_any.downcast_ref::<CoalescePartitionsExec>().is_some()
        || plan_any.downcast_ref::<WindowAggExec>().is_some()
     {
-        Some(PlanWithKeyRequirements::new(requirements.plan.clone()))
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![None; plan.children().len()],
+        ))
     } else {
         // By default, push down the parent requirements to children
-        let children_len = requirements.plan.children().len();
-        Some(PlanWithKeyRequirements {
-            plan: requirements.plan.clone(),
-            required_key_ordering: vec![],
-            request_key_ordering: vec![Some(parent_required.clone()); children_len],
-        })
-    };
-    Ok(if let Some(transformed) = transformed {
-        Transformed::Yes(transformed)
-    } else {
-        Transformed::No(requirements)
-    })
+        let children_len = plan.children().len();
+        Ok((
+            TreeNodeRecursion::Continue,
+            vec![Some(parent_required.clone()); children_len],
+        ))
+    }
 }
 
 fn reorder_partitioned_join_keys<F>(
@@ -432,7 +442,10 @@ fn reorder_partitioned_join_keys<F>(
     on: &[(Column, Column)],
     sort_options: Vec<SortOptions>,
     join_constructor: &F,
-) -> Result<PlanWithKeyRequirements>
+) -> Result<(
+    Arc<dyn ExecutionPlan>,
+    Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
+)>
 where
     F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn ExecutionPlan>>,
 {
@@ -455,35 +468,32 @@ where
                 new_sort_options.push(sort_options[new_positions[idx]])
             }
 
-            Ok(PlanWithKeyRequirements {
-                plan: join_constructor((new_join_on, new_sort_options))?,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((
+                join_constructor((new_join_on, new_sort_options))?,
+                vec![Some(left_keys), Some(right_keys)],
+            ))
         } else {
-            Ok(PlanWithKeyRequirements {
-                plan: join_plan,
-                required_key_ordering: vec![],
-                request_key_ordering: vec![Some(left_keys), Some(right_keys)],
-            })
+            Ok((join_plan, vec![Some(left_keys), Some(right_keys)]))
         }
     } else {
-        Ok(PlanWithKeyRequirements {
-            plan: join_plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![
+        Ok((
+            join_plan,
+            vec![
                 Some(join_key_pairs.left_keys),
                 Some(join_key_pairs.right_keys),
             ],
-        })
+        ))
     }
 }
 
 fn reorder_aggregate_keys(
     agg_plan: Arc<dyn ExecutionPlan>,
     parent_required: &[Arc<dyn PhysicalExpr>],
     agg_exec: &AggregateExec,
-) -> Result<PlanWithKeyRequirements> {
+) -> Result<(
+    Arc<dyn ExecutionPlan>,
+    Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
+)> {
     let output_columns = agg_exec
         .group_by()
         .expr()
@@ -501,11 +511,15 @@ fn reorder_aggregate_keys(
         || !agg_exec.group_by().null_expr().is_empty()
         || physical_exprs_equal(&output_exprs, parent_required)
     {
-        Ok(PlanWithKeyRequirements::new(agg_plan))
+        let request_key_ordering = vec![None; agg_plan.children().len()];
+        Ok((agg_plan, request_key_ordering))
     } else {
         let new_positions = expected_expr_positions(&output_exprs, parent_required);
         match new_positions {
-            None => Ok(PlanWithKeyRequirements::new(agg_plan)),
+            None => {
+                let request_key_ordering = vec![None; agg_plan.children().len()];
+                Ok((agg_plan, request_key_ordering))
+            }
             Some(positions) => {
                 let new_partial_agg = if let Some(agg_exec) =
                     agg_exec.input().as_any().downcast_ref::<AggregateExec>()
@@ -577,11 +591,13 @@ fn reorder_aggregate_keys(
                            .push((Arc::new(Column::new(name, idx)) as _, name.clone()))
                    }
                    // TODO merge adjacent Projections if there are
-                   Ok(PlanWithKeyRequirements::new(Arc::new(
-                       ProjectionExec::try_new(proj_exprs, new_final_agg)?,
-                   )))
+                   let new_plan =
+                       Arc::new(ProjectionExec::try_new(proj_exprs, new_final_agg)?);
+                   let request_key_ordering = vec![None; new_plan.children().len()];
+                   Ok((new_plan, request_key_ordering))
                } else {
-                   Ok(PlanWithKeyRequirements::new(agg_plan))
+                   let request_key_ordering = vec![None; agg_plan.children().len()];
+                   Ok((agg_plan, request_key_ordering))
                }
            }
        }
@@ -1539,93 +1555,6 @@ struct JoinKeyPairs {
     right_keys: Vec<Arc<dyn PhysicalExpr>>,
 }
 
-#[derive(Debug, Clone)]
-struct PlanWithKeyRequirements {
-    plan: Arc<dyn ExecutionPlan>,
-    /// Parent required key ordering
-    required_key_ordering: Vec<Arc<dyn PhysicalExpr>>,
-    /// The request key ordering to children
-    request_key_ordering: Vec<Option<Vec<Arc<dyn PhysicalExpr>>>>,
-}
-
-impl PlanWithKeyRequirements {
-    fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
-        let children_len = plan.children().len();
-        PlanWithKeyRequirements {
-            plan,
-            required_key_ordering: vec![],
-            request_key_ordering: vec![None; children_len],
-        }
-    }
-
-    fn children(&self) -> Vec<PlanWithKeyRequirements> {
-        let plan_children = self.plan.children();
-        assert_eq!(plan_children.len(), self.request_key_ordering.len());
-        plan_children
-            .into_iter()
-            .zip(self.request_key_ordering.clone())
-            .map(|(child, required)| {
-                let from_parent = required.unwrap_or_default();
-                let length = child.children().len();
-                PlanWithKeyRequirements {
-                    plan: child,
-                    required_key_ordering: from_parent,
-                    request_key_ordering: vec![None; length],
-                }
-            })
-            .collect()
-    }
-}
-
-impl TreeNode for PlanWithKeyRequirements {
-    fn apply_children<F>(&self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&Self) -> Result<TreeNodeRecursion>,
-    {
-        self.children().iter().for_each_till_continue(f)
-    }
-
-    fn map_children<F>(self, transform: F) -> Result<Self>
-    where
-        F: FnMut(Self) -> Result<Self>,
-    {
-        let children = self.children();
-        if !children.is_empty() {
-            let new_children: Result<Vec<_>> =
-                children.into_iter().map(transform).collect();
-
-            let children_plans = new_children?
-                .into_iter()
-                .map(|child| child.plan)
-                .collect::<Vec<_>>();
-            let new_plan = with_new_children_if_necessary(self.plan, children_plans)?;
-            Ok(PlanWithKeyRequirements {
-                plan: new_plan.into(),
-                required_key_ordering: self.required_key_ordering,
-                request_key_ordering: self.request_key_ordering,
-            })
-        } else {
-            Ok(self)
-        }
-    }
-
-    fn transform_children<F>(&mut self, f: &mut F) -> Result<TreeNodeRecursion>
-    where
-        F: FnMut(&mut Self) -> Result<TreeNodeRecursion>,
-    {
-        let mut children = self.children();
-        if !children.is_empty() {
-            let tnr = children.iter_mut().for_each_till_continue(f)?;
-            let children_plans = children.into_iter().map(|c| c.plan).collect();
-            self.plan =
-                with_new_children_if_necessary(self.plan.clone(), children_plans)?.into();
-            Ok(tnr)
-        } else {
-            Ok(TreeNodeRecursion::Continue)
-        }
-    }
-}
-
 /// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
 #[cfg(feature = "parquet")]
 #[cfg(test)]
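
One convention worth noting: every branch of the refactored callback and both helpers now return a Vec<Option<...>> with exactly one slot per child, in child order. A hypothetical, trimmed-down illustration of that shape follows; the Requirement alias and helper names are made up for this note and stand in for Vec<Arc<dyn PhysicalExpr>>.

// Hypothetical illustration of the per-child requirement vector: one Option
// slot per child, in child order. `Requirement` is a stand-in type so the
// example stays self-contained.
type Requirement = Vec<String>;

// Default branch: forward the parent's requirement to every child.
fn push_to_all_children(parent: &Requirement, children: usize) -> Vec<Option<Requirement>> {
    vec![Some(parent.clone()); children]
}

// "Cannot satisfy" branches: hand every child an empty requirement.
fn clear_requirements(children: usize) -> Vec<Option<Requirement>> {
    vec![None; children]
}

fn main() {
    let parent: Requirement = vec!["a@0".into(), "b@1".into()];
    let forwarded = push_to_all_children(&parent, 2);
    assert_eq!(forwarded.len(), 2);
    assert_eq!(forwarded[0].as_ref(), Some(&parent));
    let cleared = clear_requirements(3);
    assert!(cleared.iter().all(|slot| slot.is_none()));
}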

0 commit comments