Skip to content

Commit

Permalink
Improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
blaginin committed Oct 30, 2024
1 parent 4e1b81e commit e475464
Showing 1 changed file with 70 additions and 68 deletions.
138 changes: 70 additions & 68 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ impl<T> Transformed<T> {
}
}

/// Replaces recursion state with the given one
pub fn with_tnr(mut self, tnr: TreeNodeRecursion) -> Self {
self.tnr = tnr;
self
Expand Down Expand Up @@ -956,7 +957,8 @@ pub trait DynTreeNode {
) -> Result<Arc<Self>>;
}

pub struct LegacyRewriter<
/// Adapter from the old function-based rewriter to the new Transformer one
struct FuncRewriter<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
Expand All @@ -970,7 +972,7 @@ impl<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
> LegacyRewriter<FD, FU, Node>
> FuncRewriter<FD, FU, Node>
{
pub fn new(f_down_func: FD, f_up_func: FU) -> Self {
Self {
Expand All @@ -984,7 +986,7 @@ impl<
FD: FnMut(Node) -> Result<Transformed<Node>>,
FU: FnMut(Node) -> Result<Transformed<Node>>,
Node: TreeNode,
> TreeNodeRewriter for LegacyRewriter<FD, FU, Node>
> TreeNodeRewriter for FuncRewriter<FD, FU, Node>
{
type Node = Node;

Expand All @@ -997,7 +999,8 @@ impl<
}
}

macro_rules! update_rec_node {
/// Replaces node's children and recomputes the state
macro_rules! update_node_after_recursion {
($NAME:ident, $CHILDREN:ident) => {{
$NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed);

Expand All @@ -1011,6 +1014,7 @@ macro_rules! update_rec_node {

/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
/// (such as [`Arc<dyn PhysicalExpr>`]).
/// Unlike [`TreeNode`], performs node traversal iteratively rather than recursively to avoid stack overflow
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
Expand Down Expand Up @@ -1051,41 +1055,36 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
f_down: FD,
f_up: FU,
) -> Result<Transformed<Self>> {
self.rewrite(&mut LegacyRewriter::new(f_down, f_up))
self.rewrite(&mut FuncRewriter::new(f_down, f_up))
}

fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
) -> Result<Transformed<Self>> {
self.rewrite(&mut LegacyRewriter::new(f, |node| {
Ok(Transformed::no(node))
}))
self.rewrite(&mut FuncRewriter::new(f, |node| Ok(Transformed::no(node))))
}

fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
) -> Result<Transformed<Self>> {
self.rewrite(&mut LegacyRewriter::new(
|node| Ok(Transformed::no(node)),
f,
))
self.rewrite(&mut FuncRewriter::new(|node| Ok(Transformed::no(node)), f))
}
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
let mut stack = vec![ProcessingState::NotStarted(self)];
let mut stack = vec![TransformingState::NotStarted(self)];

while let Some(item) = stack.pop() {
match item {
ProcessingState::NotStarted(node) => {
TransformingState::NotStarted(node) => {
let node = rewriter.f_down(node)?;

stack.push(match node.tnr {
TreeNodeRecursion::Continue => {
ProcessingState::ProcessingChildren {
TransformingState::ProcessingChildren {
non_processed_children: node
.data
.arc_children()
Expand All @@ -1097,59 +1096,68 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
processed_children: vec![],
}
}
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
node.with_tnr(TreeNodeRecursion::Continue),
),
TreeNodeRecursion::Jump => {
TransformingState::ProcessedAllChildren(
// No need to process children, we can just this stage
node.with_tnr(TreeNodeRecursion::Continue),
)
}
TreeNodeRecursion::Stop => {
ProcessingState::ProcessedAllChildren(node)
TransformingState::ProcessedAllChildren(node)
}
})
}
ProcessingState::ProcessingChildren {
TransformingState::ProcessingChildren {
mut item,
mut non_processed_children,
mut processed_children,
} => match item.tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
if let Some(non_processed_item) = non_processed_children.pop() {
stack.push(ProcessingState::ProcessingChildren {
item,
non_processed_children,
processed_children,
});
stack.push(ProcessingState::NotStarted(non_processed_item));
stack.extend([
// This node still has children, so put it back in the stack
TransformingState::ProcessingChildren {
item,
non_processed_children,
processed_children,
},
// Also put the child which will be processed first
TransformingState::NotStarted(non_processed_item),
]);
} else {
stack.push(ProcessingState::ProcessedAllChildren(
update_rec_node!(item, processed_children),
stack.push(TransformingState::ProcessedAllChildren(
update_node_after_recursion!(item, processed_children),
))
}
}
TreeNodeRecursion::Stop => {
// At this point, we might have some children we haven't yet processed
processed_children.extend(
non_processed_children
.into_iter()
.rev()
.map(Transformed::no),
);
stack.push(ProcessingState::ProcessedAllChildren(
update_rec_node!(item, processed_children),
stack.push(TransformingState::ProcessedAllChildren(
update_node_after_recursion!(item, processed_children),
));
}
},
ProcessingState::ProcessedAllChildren(node) => {
TransformingState::ProcessedAllChildren(node) => {
let node = node.transform_parent(|n| rewriter.f_up(n))?;

if let Some(ProcessingState::ProcessingChildren {
if let Some(TransformingState::ProcessingChildren {
item: mut parent_node,
non_processed_children,
mut processed_children,
..
}) = stack.pop()
{
// We need use returned recursion state when processing the remaining children
parent_node.tnr = node.tnr;
processed_children.push(node);

stack.push(ProcessingState::ProcessingChildren {
stack.push(TransformingState::ProcessingChildren {
item: parent_node,
non_processed_children,
processed_children,
Expand Down Expand Up @@ -1189,10 +1197,7 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
item,
tnr: TreeNodeRecursion::Continue,
},
TreeNodeRecursion::Stop => VisitingState::VisitedAllChildren {
item,
tnr: TreeNodeRecursion::Stop,
},
TreeNodeRecursion::Stop => return Ok(tnr),
});
}
VisitingState::VisitingChildren {
Expand All @@ -1202,12 +1207,15 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
} => match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
if let Some(non_processed_item) = non_processed_children.pop() {
stack.push(VisitingState::VisitingChildren {
item,
non_processed_children,
tnr,
});
stack.push(VisitingState::NotStarted(non_processed_item));
stack.extend([
// Returning the node on the stack because there are more children to process
VisitingState::VisitingChildren {
item,
non_processed_children,
tnr,
},
VisitingState::NotStarted(non_processed_item),
]);
} else {
stack.push(VisitingState::VisitedAllChildren { item, tnr });
}
Expand All @@ -1220,10 +1228,10 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
let tnr = tnr.visit_parent(|| visitor.f_up(item))?;

if let Some(VisitingState::VisitingChildren {
item,
non_processed_children,
..
}) = stack.pop()
item,
non_processed_children,
.. // we don't care about the parent recursion state, because it will be replaced with the current state anyway
}) = stack.pop()
{
stack.push(VisitingState::VisitingChildren {
item,
Expand All @@ -1242,35 +1250,32 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
}
}

#[derive(Debug)]
enum ProcessingState<T> {
/// Node iterative transformation state. Each node on the stack is visited several times, state determines the operation that is going to run next
enum TransformingState<T> {
/// Node was just created. When executed, f_down will be called.
NotStarted(T),
// ← at this point, f_down is called
/// DFS over node's children
ProcessingChildren {
item: Transformed<T>,
non_processed_children: Vec<T>,
processed_children: Vec<Transformed<T>>,
},
// ← at this point, all children are processed
/// All children are processed (or jumped through). When executed, f_up may be called
ProcessedAllChildren(Transformed<T>),
// ← at this point, f_up is called
}

#[derive(Debug)]
/// Node iterative visit state. Each node on the stack is visited several times, state determines the operation that is going to run next
enum VisitingState<'a, T> {
/// Node was just created. When executed, f_down will be called.
NotStarted(&'a T),
// ← at this point, f_down is called
/// DFS over node's children. During processing, reference to children are removed from the inner stack
VisitingChildren {
item: &'a T,
non_processed_children: Vec<&'a T>,
tnr: TreeNodeRecursion,
},
// ← at this point, all children are visited
VisitedAllChildren {
item: &'a T,
tnr: TreeNodeRecursion,
},
// ← at this point, f_up is called
/// All children are processed (or jumped through). When executed, f_up may be called
VisitedAllChildren { item: &'a T, tnr: TreeNodeRecursion },
}

/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
Expand Down Expand Up @@ -1409,11 +1414,8 @@ pub(crate) mod tests {
.collect()
}

/// Implement this train in order for your struct to be supported in parametrisation
pub(crate) trait TestTree<T: PartialEq>
where
T: Sized,
{
/// [`node_tests`] uses methods when generating tests
pub(crate) trait TestTree<T: PartialEq> {
fn new_with_children(children: Vec<Self>, data: T) -> Self
where
Self: Sized;
Expand All @@ -1427,7 +1429,7 @@ pub(crate) mod tests {
fn with_children_from(data: T, other: Self) -> Self;
}

macro_rules! gen_tests {
macro_rules! node_tests {
($TYPE:ident) => {
fn visit_continue<T: PartialEq>(_: &$TYPE<T>) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
Expand Down Expand Up @@ -2546,7 +2548,7 @@ pub(crate) mod tests {
}
}

gen_tests!(TestTreeNode);
node_tests!(TestTreeNode);
}

pub mod test_dyn_tree_node {
Expand Down Expand Up @@ -2607,7 +2609,7 @@ pub(crate) mod tests {

type ArcTestNode<T> = Arc<DynTestNode<T>>;

gen_tests!(ArcTestNode);
node_tests!(ArcTestNode);

#[test]
fn test_large_tree() {
Expand Down

0 comments on commit e475464

Please sign in to comment.