diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1bd36e16afee..fc322d622da5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -779,6 +779,7 @@ impl Transformed { } } + /// Replaces recursion state with the given one pub fn with_tnr(mut self, tnr: TreeNodeRecursion) -> Self { self.tnr = tnr; self @@ -956,7 +957,8 @@ pub trait DynTreeNode { ) -> Result>; } -pub struct LegacyRewriter< +/// Adapter from the old function-based rewriter to the new Transformer one +struct FuncRewriter< FD: FnMut(Node) -> Result>, FU: FnMut(Node) -> Result>, Node: TreeNode, @@ -970,7 +972,7 @@ impl< FD: FnMut(Node) -> Result>, FU: FnMut(Node) -> Result>, Node: TreeNode, - > LegacyRewriter + > FuncRewriter { pub fn new(f_down_func: FD, f_up_func: FU) -> Self { Self { @@ -984,7 +986,7 @@ impl< FD: FnMut(Node) -> Result>, FU: FnMut(Node) -> Result>, Node: TreeNode, - > TreeNodeRewriter for LegacyRewriter + > TreeNodeRewriter for FuncRewriter { type Node = Node; @@ -997,7 +999,8 @@ impl< } } -macro_rules! update_rec_node { +/// Replaces node's children and recomputes the state +macro_rules! update_node_after_recursion { ($NAME:ident, $CHILDREN:ident) => {{ $NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed); @@ -1011,6 +1014,7 @@ macro_rules! update_rec_node { /// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] /// (such as [`Arc`]). +/// Unlike [`TreeNode`], performs node traversal iteratively rather than recursively to avoid stack overflow impl TreeNode for Arc { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, @@ -1051,41 +1055,36 @@ impl TreeNode for Arc { f_down: FD, f_up: FU, ) -> Result> { - self.rewrite(&mut LegacyRewriter::new(f_down, f_up)) + self.rewrite(&mut FuncRewriter::new(f_down, f_up)) } fn transform_down Result>>( self, f: F, ) -> Result> { - self.rewrite(&mut LegacyRewriter::new(f, |node| { - Ok(Transformed::no(node)) - })) + self.rewrite(&mut FuncRewriter::new(f, |node| Ok(Transformed::no(node)))) } fn transform_up Result>>( self, f: F, ) -> Result> { - self.rewrite(&mut LegacyRewriter::new( - |node| Ok(Transformed::no(node)), - f, - )) + self.rewrite(&mut FuncRewriter::new(|node| Ok(Transformed::no(node)), f)) } fn rewrite>( self, rewriter: &mut R, ) -> Result> { - let mut stack = vec![ProcessingState::NotStarted(self)]; + let mut stack = vec![TransformingState::NotStarted(self)]; while let Some(item) = stack.pop() { match item { - ProcessingState::NotStarted(node) => { + TransformingState::NotStarted(node) => { let node = rewriter.f_down(node)?; stack.push(match node.tnr { TreeNodeRecursion::Continue => { - ProcessingState::ProcessingChildren { + TransformingState::ProcessingChildren { non_processed_children: node .data .arc_children() @@ -1097,59 +1096,68 @@ impl TreeNode for Arc { processed_children: vec![], } } - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( - node.with_tnr(TreeNodeRecursion::Continue), - ), + TreeNodeRecursion::Jump => { + TransformingState::ProcessedAllChildren( + // No need to process children, we can just this stage + node.with_tnr(TreeNodeRecursion::Continue), + ) + } TreeNodeRecursion::Stop => { - ProcessingState::ProcessedAllChildren(node) + TransformingState::ProcessedAllChildren(node) } }) } - ProcessingState::ProcessingChildren { + TransformingState::ProcessingChildren { mut item, mut non_processed_children, mut processed_children, } => match item.tnr { TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { if let Some(non_processed_item) = non_processed_children.pop() { - stack.push(ProcessingState::ProcessingChildren { - item, - non_processed_children, - processed_children, - }); - stack.push(ProcessingState::NotStarted(non_processed_item)); + stack.extend([ + // This node still has children, so put it back in the stack + TransformingState::ProcessingChildren { + item, + non_processed_children, + processed_children, + }, + // Also put the child which will be processed first + TransformingState::NotStarted(non_processed_item), + ]); } else { - stack.push(ProcessingState::ProcessedAllChildren( - update_rec_node!(item, processed_children), + stack.push(TransformingState::ProcessedAllChildren( + update_node_after_recursion!(item, processed_children), )) } } TreeNodeRecursion::Stop => { + // At this point, we might have some children we haven't yet processed processed_children.extend( non_processed_children .into_iter() .rev() .map(Transformed::no), ); - stack.push(ProcessingState::ProcessedAllChildren( - update_rec_node!(item, processed_children), + stack.push(TransformingState::ProcessedAllChildren( + update_node_after_recursion!(item, processed_children), )); } }, - ProcessingState::ProcessedAllChildren(node) => { + TransformingState::ProcessedAllChildren(node) => { let node = node.transform_parent(|n| rewriter.f_up(n))?; - if let Some(ProcessingState::ProcessingChildren { + if let Some(TransformingState::ProcessingChildren { item: mut parent_node, non_processed_children, mut processed_children, .. }) = stack.pop() { + // We need use returned recursion state when processing the remaining children parent_node.tnr = node.tnr; processed_children.push(node); - stack.push(ProcessingState::ProcessingChildren { + stack.push(TransformingState::ProcessingChildren { item: parent_node, non_processed_children, processed_children, @@ -1189,10 +1197,7 @@ impl TreeNode for Arc { item, tnr: TreeNodeRecursion::Continue, }, - TreeNodeRecursion::Stop => VisitingState::VisitedAllChildren { - item, - tnr: TreeNodeRecursion::Stop, - }, + TreeNodeRecursion::Stop => return Ok(tnr), }); } VisitingState::VisitingChildren { @@ -1202,12 +1207,15 @@ impl TreeNode for Arc { } => match tnr { TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { if let Some(non_processed_item) = non_processed_children.pop() { - stack.push(VisitingState::VisitingChildren { - item, - non_processed_children, - tnr, - }); - stack.push(VisitingState::NotStarted(non_processed_item)); + stack.extend([ + // Returning the node on the stack because there are more children to process + VisitingState::VisitingChildren { + item, + non_processed_children, + tnr, + }, + VisitingState::NotStarted(non_processed_item), + ]); } else { stack.push(VisitingState::VisitedAllChildren { item, tnr }); } @@ -1220,10 +1228,10 @@ impl TreeNode for Arc { let tnr = tnr.visit_parent(|| visitor.f_up(item))?; if let Some(VisitingState::VisitingChildren { - item, - non_processed_children, - .. - }) = stack.pop() + item, + non_processed_children, + .. // we don't care about the parent recursion state, because it will be replaced with the current state anyway + }) = stack.pop() { stack.push(VisitingState::VisitingChildren { item, @@ -1242,35 +1250,32 @@ impl TreeNode for Arc { } } -#[derive(Debug)] -enum ProcessingState { +/// Node iterative transformation state. Each node on the stack is visited several times, state determines the operation that is going to run next +enum TransformingState { + /// Node was just created. When executed, f_down will be called. NotStarted(T), - // ← at this point, f_down is called + /// DFS over node's children ProcessingChildren { item: Transformed, non_processed_children: Vec, processed_children: Vec>, }, - // ← at this point, all children are processed + /// All children are processed (or jumped through). When executed, f_up may be called ProcessedAllChildren(Transformed), - // ← at this point, f_up is called } -#[derive(Debug)] +/// Node iterative visit state. Each node on the stack is visited several times, state determines the operation that is going to run next enum VisitingState<'a, T> { + /// Node was just created. When executed, f_down will be called. NotStarted(&'a T), - // ← at this point, f_down is called + /// DFS over node's children. During processing, reference to children are removed from the inner stack VisitingChildren { item: &'a T, non_processed_children: Vec<&'a T>, tnr: TreeNodeRecursion, }, - // ← at this point, all children are visited - VisitedAllChildren { - item: &'a T, - tnr: TreeNodeRecursion, - }, - // ← at this point, f_up is called + /// All children are processed (or jumped through). When executed, f_up may be called + VisitedAllChildren { item: &'a T, tnr: TreeNodeRecursion }, } /// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for @@ -1409,11 +1414,8 @@ pub(crate) mod tests { .collect() } - /// Implement this train in order for your struct to be supported in parametrisation - pub(crate) trait TestTree - where - T: Sized, - { + /// [`node_tests`] uses methods when generating tests + pub(crate) trait TestTree { fn new_with_children(children: Vec, data: T) -> Self where Self: Sized; @@ -1427,7 +1429,7 @@ pub(crate) mod tests { fn with_children_from(data: T, other: Self) -> Self; } - macro_rules! gen_tests { + macro_rules! node_tests { ($TYPE:ident) => { fn visit_continue(_: &$TYPE) -> Result { Ok(TreeNodeRecursion::Continue) @@ -2546,7 +2548,7 @@ pub(crate) mod tests { } } - gen_tests!(TestTreeNode); + node_tests!(TestTreeNode); } pub mod test_dyn_tree_node { @@ -2607,7 +2609,7 @@ pub(crate) mod tests { type ArcTestNode = Arc>; - gen_tests!(ArcTestNode); + node_tests!(ArcTestNode); #[test] fn test_large_tree() {