@@ -191,17 +191,15 @@ impl EnforceDistribution {
191191impl PhysicalOptimizerRule for EnforceDistribution {
192192 fn optimize (
193193 & self ,
194- plan : Arc < dyn ExecutionPlan > ,
194+ mut plan : Arc < dyn ExecutionPlan > ,
195195 config : & ConfigOptions ,
196196 ) -> Result < Arc < dyn ExecutionPlan > > {
197197 let top_down_join_key_reordering = config. optimizer . top_down_join_key_reordering ;
198198
199199 let adjusted = if top_down_join_key_reordering {
200200 // Run a top-down process to adjust input key ordering recursively
201- let plan_requirements = PlanWithKeyRequirements :: new ( plan) ;
202- let adjusted =
203- plan_requirements. transform_down_old ( & adjust_input_keys_ordering) ?;
204- adjusted. plan
201+ plan. transform_down_with_payload ( & mut adjust_input_keys_ordering, None ) ?;
202+ plan
205203 } else {
206204 // Run a bottom-up process
207205 plan. transform_up_old ( & |plan| {
@@ -270,11 +268,12 @@ impl PhysicalOptimizerRule for EnforceDistribution {
270268/// 5) For other types of operators, by default, push down the parent requirements to children.
271269///
272270fn adjust_input_keys_ordering (
273- requirements : PlanWithKeyRequirements ,
274- ) -> Result < Transformed < PlanWithKeyRequirements > > {
275- let parent_required = requirements. required_key_ordering . clone ( ) ;
276- let plan_any = requirements. plan . as_any ( ) ;
277- let transformed = if let Some ( HashJoinExec {
271+ plan : & mut Arc < dyn ExecutionPlan > ,
272+ required_key_ordering : Option < Vec < Arc < dyn PhysicalExpr > > > ,
273+ ) -> Result < ( TreeNodeRecursion , Vec < Option < Vec < Arc < dyn PhysicalExpr > > > > ) > {
274+ let parent_required = required_key_ordering. unwrap_or_default ( ) . clone ( ) ;
275+ let plan_any = plan. as_any ( ) ;
276+ if let Some ( HashJoinExec {
278277 left,
279278 right,
280279 on,
@@ -299,13 +298,15 @@ fn adjust_input_keys_ordering(
299298 * null_equals_null,
300299 ) ?) as Arc < dyn ExecutionPlan > )
301300 } ;
302- Some ( reorder_partitioned_join_keys (
303- requirements . plan . clone ( ) ,
301+ let ( new_plan , request_key_ordering ) = reorder_partitioned_join_keys (
302+ plan. clone ( ) ,
304303 & parent_required,
305304 on,
306305 vec ! [ ] ,
307306 & join_constructor,
308- ) ?)
307+ ) ?;
308+ * plan = new_plan;
309+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
309310 }
310311 PartitionMode :: CollectLeft => {
311312 let new_right_request = match join_type {
@@ -323,30 +324,28 @@ fn adjust_input_keys_ordering(
323324 } ;
324325
325326 // Push down requirements to the right side
326- Some ( PlanWithKeyRequirements {
327- plan : requirements. plan . clone ( ) ,
328- required_key_ordering : vec ! [ ] ,
329- request_key_ordering : vec ! [ None , new_right_request] ,
330- } )
327+ Ok ( ( TreeNodeRecursion :: Continue , vec ! [ None , new_right_request] ) )
331328 }
332329 PartitionMode :: Auto => {
333330 // Cannot satisfy, clear the current requirements and generate new empty requirements
334- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
331+ Ok ( (
332+ TreeNodeRecursion :: Continue ,
333+ vec ! [ None ; plan. children( ) . len( ) ] ,
334+ ) )
335335 }
336336 }
337337 } else if let Some ( CrossJoinExec { left, .. } ) =
338338 plan_any. downcast_ref :: < CrossJoinExec > ( )
339339 {
340340 let left_columns_len = left. schema ( ) . fields ( ) . len ( ) ;
341341 // Push down requirements to the right side
342- Some ( PlanWithKeyRequirements {
343- plan : requirements. plan . clone ( ) ,
344- required_key_ordering : vec ! [ ] ,
345- request_key_ordering : vec ! [
342+ Ok ( (
343+ TreeNodeRecursion :: Continue ,
344+ vec ! [
346345 None ,
347346 shift_right_required( & parent_required, left_columns_len) ,
348347 ] ,
349- } )
348+ ) )
350349 } else if let Some ( SortMergeJoinExec {
351350 left,
352351 right,
@@ -368,26 +367,38 @@ fn adjust_input_keys_ordering(
368367 * null_equals_null,
369368 ) ?) as Arc < dyn ExecutionPlan > )
370369 } ;
371- Some ( reorder_partitioned_join_keys (
372- requirements . plan . clone ( ) ,
370+ let ( new_plan , request_key_ordering ) = reorder_partitioned_join_keys (
371+ plan. clone ( ) ,
373372 & parent_required,
374373 on,
375374 sort_options. clone ( ) ,
376375 & join_constructor,
377- ) ?)
376+ ) ?;
377+ * plan = new_plan;
378+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
378379 } else if let Some ( aggregate_exec) = plan_any. downcast_ref :: < AggregateExec > ( ) {
379380 if !parent_required. is_empty ( ) {
380381 match aggregate_exec. mode ( ) {
381- AggregateMode :: FinalPartitioned => Some ( reorder_aggregate_keys (
382- requirements. plan . clone ( ) ,
383- & parent_required,
384- aggregate_exec,
385- ) ?) ,
386- _ => Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) ) ,
382+ AggregateMode :: FinalPartitioned => {
383+ let ( new_plan, request_key_ordering) = reorder_aggregate_keys (
384+ plan. clone ( ) ,
385+ & parent_required,
386+ aggregate_exec,
387+ ) ?;
388+ * plan = new_plan;
389+ Ok ( ( TreeNodeRecursion :: Continue , request_key_ordering) )
390+ }
391+ _ => Ok ( (
392+ TreeNodeRecursion :: Continue ,
393+ vec ! [ None ; plan. children( ) . len( ) ] ,
394+ ) ) ,
387395 }
388396 } else {
389397 // Keep everything unchanged
390- None
398+ Ok ( (
399+ TreeNodeRecursion :: Continue ,
400+ vec ! [ None ; plan. children( ) . len( ) ] ,
401+ ) )
391402 }
392403 } else if let Some ( proj) = plan_any. downcast_ref :: < ProjectionExec > ( ) {
393404 let expr = proj. expr ( ) ;
@@ -396,34 +407,33 @@ fn adjust_input_keys_ordering(
396407 // Construct a mapping from the new name to the original Column
397408 let new_required = map_columns_before_projection ( & parent_required, expr) ;
398409 if new_required. len ( ) == parent_required. len ( ) {
399- Some ( PlanWithKeyRequirements {
400- plan : requirements. plan . clone ( ) ,
401- required_key_ordering : vec ! [ ] ,
402- request_key_ordering : vec ! [ Some ( new_required. clone( ) ) ] ,
403- } )
410+ Ok ( (
411+ TreeNodeRecursion :: Continue ,
412+ vec ! [ Some ( new_required. clone( ) ) ] ,
413+ ) )
404414 } else {
405415 // Cannot satisfy, clear the current requirements and generate new empty requirements
406- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
416+ Ok ( (
417+ TreeNodeRecursion :: Continue ,
418+ vec ! [ None ; plan. children( ) . len( ) ] ,
419+ ) )
407420 }
408421 } else if plan_any. downcast_ref :: < RepartitionExec > ( ) . is_some ( )
409422 || plan_any. downcast_ref :: < CoalescePartitionsExec > ( ) . is_some ( )
410423 || plan_any. downcast_ref :: < WindowAggExec > ( ) . is_some ( )
411424 {
412- Some ( PlanWithKeyRequirements :: new ( requirements. plan . clone ( ) ) )
425+ Ok ( (
426+ TreeNodeRecursion :: Continue ,
427+ vec ! [ None ; plan. children( ) . len( ) ] ,
428+ ) )
413429 } else {
414430 // By default, push down the parent requirements to children
415- let children_len = requirements. plan . children ( ) . len ( ) ;
416- Some ( PlanWithKeyRequirements {
417- plan : requirements. plan . clone ( ) ,
418- required_key_ordering : vec ! [ ] ,
419- request_key_ordering : vec ! [ Some ( parent_required. clone( ) ) ; children_len] ,
420- } )
421- } ;
422- Ok ( if let Some ( transformed) = transformed {
423- Transformed :: Yes ( transformed)
424- } else {
425- Transformed :: No ( requirements)
426- } )
431+ let children_len = plan. children ( ) . len ( ) ;
432+ Ok ( (
433+ TreeNodeRecursion :: Continue ,
434+ vec ! [ Some ( parent_required. clone( ) ) ; children_len] ,
435+ ) )
436+ }
427437}
428438
429439fn reorder_partitioned_join_keys < F > (
@@ -432,7 +442,10 @@ fn reorder_partitioned_join_keys<F>(
432442 on : & [ ( Column , Column ) ] ,
433443 sort_options : Vec < SortOptions > ,
434444 join_constructor : & F ,
435- ) -> Result < PlanWithKeyRequirements >
445+ ) -> Result < (
446+ Arc < dyn ExecutionPlan > ,
447+ Vec < Option < Vec < Arc < dyn PhysicalExpr > > > > ,
448+ ) >
436449where
437450 F : Fn ( ( Vec < ( Column , Column ) > , Vec < SortOptions > ) ) -> Result < Arc < dyn ExecutionPlan > > ,
438451{
@@ -455,35 +468,32 @@ where
455468 new_sort_options. push ( sort_options[ new_positions[ idx] ] )
456469 }
457470
458- Ok ( PlanWithKeyRequirements {
459- plan : join_constructor ( ( new_join_on, new_sort_options) ) ?,
460- required_key_ordering : vec ! [ ] ,
461- request_key_ordering : vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
462- } )
471+ Ok ( (
472+ join_constructor ( ( new_join_on, new_sort_options) ) ?,
473+ vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
474+ ) )
463475 } else {
464- Ok ( PlanWithKeyRequirements {
465- plan : join_plan,
466- required_key_ordering : vec ! [ ] ,
467- request_key_ordering : vec ! [ Some ( left_keys) , Some ( right_keys) ] ,
468- } )
476+ Ok ( ( join_plan, vec ! [ Some ( left_keys) , Some ( right_keys) ] ) )
469477 }
470478 } else {
471- Ok ( PlanWithKeyRequirements {
472- plan : join_plan,
473- required_key_ordering : vec ! [ ] ,
474- request_key_ordering : vec ! [
479+ Ok ( (
480+ join_plan,
481+ vec ! [
475482 Some ( join_key_pairs. left_keys) ,
476483 Some ( join_key_pairs. right_keys) ,
477484 ] ,
478- } )
485+ ) )
479486 }
480487}
481488
482489fn reorder_aggregate_keys (
483490 agg_plan : Arc < dyn ExecutionPlan > ,
484491 parent_required : & [ Arc < dyn PhysicalExpr > ] ,
485492 agg_exec : & AggregateExec ,
486- ) -> Result < PlanWithKeyRequirements > {
493+ ) -> Result < (
494+ Arc < dyn ExecutionPlan > ,
495+ Vec < Option < Vec < Arc < dyn PhysicalExpr > > > > ,
496+ ) > {
487497 let output_columns = agg_exec
488498 . group_by ( )
489499 . expr ( )
@@ -501,11 +511,15 @@ fn reorder_aggregate_keys(
501511 || !agg_exec. group_by ( ) . null_expr ( ) . is_empty ( )
502512 || physical_exprs_equal ( & output_exprs, parent_required)
503513 {
504- Ok ( PlanWithKeyRequirements :: new ( agg_plan) )
514+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
515+ Ok ( ( agg_plan, request_key_ordering) )
505516 } else {
506517 let new_positions = expected_expr_positions ( & output_exprs, parent_required) ;
507518 match new_positions {
508- None => Ok ( PlanWithKeyRequirements :: new ( agg_plan) ) ,
519+ None => {
520+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
521+ Ok ( ( agg_plan, request_key_ordering) )
522+ }
509523 Some ( positions) => {
510524 let new_partial_agg = if let Some ( agg_exec) =
511525 agg_exec. input ( ) . as_any ( ) . downcast_ref :: < AggregateExec > ( )
@@ -577,11 +591,13 @@ fn reorder_aggregate_keys(
577591 . push ( ( Arc :: new ( Column :: new ( name, idx) ) as _ , name. clone ( ) ) )
578592 }
579593 // TODO: merge adjacent Projections if there are any
580- Ok ( PlanWithKeyRequirements :: new ( Arc :: new (
581- ProjectionExec :: try_new ( proj_exprs, new_final_agg) ?,
582- ) ) )
594+ let new_plan =
595+ Arc :: new ( ProjectionExec :: try_new ( proj_exprs, new_final_agg) ?) ;
596+ let request_key_ordering = vec ! [ None ; new_plan. children( ) . len( ) ] ;
597+ Ok ( ( new_plan, request_key_ordering) )
583598 } else {
584- Ok ( PlanWithKeyRequirements :: new ( agg_plan) )
599+ let request_key_ordering = vec ! [ None ; agg_plan. children( ) . len( ) ] ;
600+ Ok ( ( agg_plan, request_key_ordering) )
585601 }
586602 }
587603 }
@@ -1539,93 +1555,6 @@ struct JoinKeyPairs {
15391555 right_keys : Vec < Arc < dyn PhysicalExpr > > ,
15401556}
15411557
1542- #[ derive( Debug , Clone ) ]
1543- struct PlanWithKeyRequirements {
1544- plan : Arc < dyn ExecutionPlan > ,
1545- /// Parent required key ordering
1546- required_key_ordering : Vec < Arc < dyn PhysicalExpr > > ,
1547- /// The request key ordering to children
1548- request_key_ordering : Vec < Option < Vec < Arc < dyn PhysicalExpr > > > > ,
1549- }
1550-
1551- impl PlanWithKeyRequirements {
1552- fn new ( plan : Arc < dyn ExecutionPlan > ) -> Self {
1553- let children_len = plan. children ( ) . len ( ) ;
1554- PlanWithKeyRequirements {
1555- plan,
1556- required_key_ordering : vec ! [ ] ,
1557- request_key_ordering : vec ! [ None ; children_len] ,
1558- }
1559- }
1560-
1561- fn children ( & self ) -> Vec < PlanWithKeyRequirements > {
1562- let plan_children = self . plan . children ( ) ;
1563- assert_eq ! ( plan_children. len( ) , self . request_key_ordering. len( ) ) ;
1564- plan_children
1565- . into_iter ( )
1566- . zip ( self . request_key_ordering . clone ( ) )
1567- . map ( |( child, required) | {
1568- let from_parent = required. unwrap_or_default ( ) ;
1569- let length = child. children ( ) . len ( ) ;
1570- PlanWithKeyRequirements {
1571- plan : child,
1572- required_key_ordering : from_parent,
1573- request_key_ordering : vec ! [ None ; length] ,
1574- }
1575- } )
1576- . collect ( )
1577- }
1578- }
1579-
1580- impl TreeNode for PlanWithKeyRequirements {
1581- fn apply_children < F > ( & self , f : & mut F ) -> Result < TreeNodeRecursion >
1582- where
1583- F : FnMut ( & Self ) -> Result < TreeNodeRecursion > ,
1584- {
1585- self . children ( ) . iter ( ) . for_each_till_continue ( f)
1586- }
1587-
1588- fn map_children < F > ( self , transform : F ) -> Result < Self >
1589- where
1590- F : FnMut ( Self ) -> Result < Self > ,
1591- {
1592- let children = self . children ( ) ;
1593- if !children. is_empty ( ) {
1594- let new_children: Result < Vec < _ > > =
1595- children. into_iter ( ) . map ( transform) . collect ( ) ;
1596-
1597- let children_plans = new_children?
1598- . into_iter ( )
1599- . map ( |child| child. plan )
1600- . collect :: < Vec < _ > > ( ) ;
1601- let new_plan = with_new_children_if_necessary ( self . plan , children_plans) ?;
1602- Ok ( PlanWithKeyRequirements {
1603- plan : new_plan. into ( ) ,
1604- required_key_ordering : self . required_key_ordering ,
1605- request_key_ordering : self . request_key_ordering ,
1606- } )
1607- } else {
1608- Ok ( self )
1609- }
1610- }
1611-
1612- fn transform_children < F > ( & mut self , f : & mut F ) -> Result < TreeNodeRecursion >
1613- where
1614- F : FnMut ( & mut Self ) -> Result < TreeNodeRecursion > ,
1615- {
1616- let mut children = self . children ( ) ;
1617- if !children. is_empty ( ) {
1618- let tnr = children. iter_mut ( ) . for_each_till_continue ( f) ?;
1619- let children_plans = children. into_iter ( ) . map ( |c| c. plan ) . collect ( ) ;
1620- self . plan =
1621- with_new_children_if_necessary ( self . plan . clone ( ) , children_plans) ?. into ( ) ;
1622- Ok ( tnr)
1623- } else {
1624- Ok ( TreeNodeRecursion :: Continue )
1625- }
1626- }
1627- }
1628-
16291558/// Since almost all of these tests explicitly use `ParquetExec` they only run with the parquet feature flag on
16301559#[ cfg( feature = "parquet" ) ]
16311560#[ cfg( test) ]
0 commit comments