Skip to content

Commit 1bedfec

Browse files
committed
Add TreeNodeMutator API
Use TreeNode API in Optimizer
1 parent 9487ca0 commit 1bedfec

35 files changed

+1049
-536
lines changed

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub fn main() -> Result<()> {
5959

6060
// then run the optimizer with our custom rule
6161
let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
62-
let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?;
62+
let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?;
6363
println!(
6464
"Optimized Logical Plan:\n\n{}\n",
6565
optimized_plan.display_indent()

datafusion/common/src/tree_node.rs

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
use std::sync::Arc;
2222

23-
use crate::Result;
23+
use crate::{error::_not_impl_err, Result};
2424

2525
/// This macro is used to control continuation behaviors during tree traversals
2626
/// based on the specified direction. Depending on `$DIRECTION` and the value of
@@ -100,6 +100,10 @@ pub trait TreeNode: Sized {
100100
/// Visit the tree node using the given [`TreeNodeVisitor`], performing a
101101
/// depth-first walk of the node and its children.
102102
///
103+
/// See also:
104+
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
105+
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
106+
///
103107
/// Consider the following tree structure:
104108
/// ```text
105109
/// ParentNode
@@ -144,6 +148,10 @@ pub trait TreeNode: Sized {
144148
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
145149
/// recursively transforming [`TreeNode`]s.
146150
///
151+
/// See also:
152+
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
153+
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
154+
///
147155
/// Consider the following tree structure:
148156
/// ```text
149157
/// ParentNode
@@ -174,6 +182,70 @@ pub trait TreeNode: Sized {
174182
})
175183
}
176184

185+
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
186+
/// recursively mutating / rewriting [`TreeNode`]s in place.
187+
///
188+
/// See also:
189+
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
190+
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
191+
///
192+
/// Consider the following tree structure:
193+
/// ```text
194+
/// ParentNode
195+
/// left: ChildNode1
196+
/// right: ChildNode2
197+
/// ```
198+
///
199+
/// Here, the nodes would be mutataed in the following order:
200+
/// ```text
201+
/// TreeNodeMutator::f_down(ParentNode)
202+
/// TreeNodeMutator::f_down(ChildNode1)
203+
/// TreeNodeMutator::f_up(ChildNode1)
204+
/// TreeNodeMutator::f_down(ChildNode2)
205+
/// TreeNodeMutator::f_up(ChildNode2)
206+
/// TreeNodeMutator::f_up(ParentNode)
207+
/// ```
208+
///
209+
/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
210+
///
211+
/// # Error Handling
212+
///
213+
/// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`],
214+
/// the recursion stops immediately and the tree may be left partially changed
215+
///
216+
/// # Changing Children During Traversal
217+
///
218+
/// If `f_down` changes the nodes children, the new children are visited
219+
/// (not the old children prior to rewrite)
220+
fn mutate<M: TreeNodeMutator<Node = Self>>(
221+
&mut self,
222+
mutator: &mut M,
223+
) -> Result<Transformed<()>> {
224+
// Note this is an inlined version of handle_transform_recursion!
225+
let pre_visited = mutator.f_down(self)?;
226+
227+
// Traverse children and then call f_up on self if necessary
228+
match pre_visited.tnr {
229+
TreeNodeRecursion::Continue => {
230+
// rewrite children recursively with mutator
231+
self.mutate_children(|c| c.mutate(mutator))?
232+
.try_transform_node_with(
233+
|_: ()| mutator.f_up(self),
234+
TreeNodeRecursion::Jump,
235+
)
236+
}
237+
TreeNodeRecursion::Jump => {
238+
// skip other children and start back up
239+
mutator.f_up(self)
240+
}
241+
TreeNodeRecursion::Stop => return Ok(pre_visited),
242+
}
243+
.map(|mut post_visited| {
244+
post_visited.transformed |= pre_visited.transformed;
245+
post_visited
246+
})
247+
}
248+
177249
/// Applies `f` to the node and its children. `f` is applied in a pre-order
178250
/// way, and it is controlled by [`TreeNodeRecursion`], which means result
179251
/// of the `f` on a node can cause an early return.
@@ -353,13 +425,38 @@ pub trait TreeNode: Sized {
353425
}
354426

355427
/// Apply the closure `F` to the node's children.
428+
///
429+
/// See `mutate_children` for rewriting in place
356430
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
357431
&self,
358432
f: &mut F,
359433
) -> Result<TreeNodeRecursion>;
360434

361-
/// Apply transform `F` to the node's children. Note that the transform `F`
362-
/// might have a direction (pre-order or post-order).
435+
/// Rewrite the node's children in place using `F`.
436+
///
437+
/// On error, `self` is left partially rewritten.
438+
///
439+
/// # Notes
440+
///
441+
/// Using [`Self::map_children`], the owned API, has clearer semantics on
442+
/// error (the node is consumed). However, it requires copying the interior
443+
/// fields of the tree node during rewrite.
444+
///
445+
/// This API writes the nodes in place, which can be faster as it avoids
446+
/// copying, but leaves the tree node in an partially rewritten state when
447+
/// an error occurs.
448+
fn mutate_children<F: FnMut(&mut Self) -> Result<Transformed<()>>>(
449+
&mut self,
450+
_f: F,
451+
) -> Result<Transformed<()>> {
452+
_not_impl_err!(
453+
"mutate_children not implemented for {} yet",
454+
std::any::type_name::<Self>()
455+
)
456+
}
457+
458+
/// Apply transform `F` to potentially rewrite the node's children. Note
459+
/// that the transform `F` might have a direction (pre-order or post-order).
363460
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
364461
self,
365462
f: F,
@@ -411,6 +508,41 @@ pub trait TreeNodeRewriter: Sized {
411508
}
412509
}
413510

511+
/// Trait for mutating (rewriting in place) [`TreeNode`]s
512+
///
513+
/// # See Also:
514+
/// * [`TreeNodeRewriter`] for rewriting owned `TreeNode`e
515+
/// * [`TreeNodeVisitor`] for visiting, but not changing, `TreeNode`s
516+
pub trait TreeNodeMutator: Sized {
517+
/// The node type to mutating.
518+
type Node: TreeNode;
519+
520+
/// Invoked while traversing down the tree before any children are mutated.
521+
/// Default implementation does nothing to the node and continues recursion.
522+
///
523+
/// # Notes
524+
///
525+
/// As the node maybe mutated in place, the returned [`Transformed`] object
526+
/// returns `()` (no data).
527+
///
528+
/// If the node's children are changed by `f_down`, the *new* children are
529+
/// visited, not the original children.
530+
fn f_down(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
531+
Ok(Transformed::no(()))
532+
}
533+
534+
/// Invoked while traversing up the tree after all children have been mutated.
535+
/// Default implementation does nothing to the node and continues recursion.
536+
///
537+
/// # Notes
538+
///
539+
/// As the node maybe mutated in place, the returned [`Transformed`] object
540+
/// returns `()` (no data).
541+
fn f_up(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
542+
Ok(Transformed::no(()))
543+
}
544+
}
545+
414546
/// Controls how [`TreeNode`] recursions should proceed.
415547
#[derive(Debug, PartialEq, Clone, Copy)]
416548
pub enum TreeNodeRecursion {
@@ -489,6 +621,11 @@ impl<T> Transformed<T> {
489621
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
490622
}
491623

624+
/// Invokes f(), depending on the value of self.tnr.
625+
///
626+
/// This is used to conditionally apply a function during a f_up tree
627+
/// traversal, if the result of children traversal was `[`TreeNodeRecursion::Continue`].
628+
///
492629
/// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`]
493630
/// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently
494631
/// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of
@@ -532,6 +669,25 @@ impl<T> Transformed<T> {
532669
}
533670
}
534671

672+
impl Transformed<()> {
673+
/// Invoke the given function `f` and combine the transformed state with
674+
/// the current state:
675+
///
676+
/// * if `f` returns an Err, returns that err
677+
///
678+
/// * If `f` returns Ok, sets `self.transformed` to `true` if either self or
679+
/// the result of `f` were transformed.
680+
pub fn and_then<F>(self, f: F) -> Result<Transformed<()>>
681+
where
682+
F: FnOnce() -> Result<Transformed<()>>,
683+
{
684+
f().map(|mut t| {
685+
t.transformed |= self.transformed;
686+
t
687+
})
688+
}
689+
}
690+
535691
/// Transformation helper to process tree nodes that are siblings.
536692
pub trait TransformedIterator: Iterator {
537693
fn map_until_stop_and_collect<

datafusion/core/src/execution/context/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,7 +1877,7 @@ impl SessionState {
18771877

18781878
// optimize the child plan, capturing the output of each optimizer
18791879
let optimized_plan = self.optimizer.optimize(
1880-
&analyzed_plan,
1880+
analyzed_plan,
18811881
self,
18821882
|optimized_plan, optimizer| {
18831883
let optimizer_name = optimizer.name().to_string();
@@ -1907,7 +1907,7 @@ impl SessionState {
19071907
let analyzed_plan =
19081908
self.analyzer
19091909
.execute_and_check(plan, self.options(), |_, _| {})?;
1910-
self.optimizer.optimize(&analyzed_plan, self, |_, _| {})
1910+
self.optimizer.optimize(analyzed_plan, self, |_, _| {})
19111911
}
19121912
}
19131913

datafusion/core/tests/optimizer_integration.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
110110
let optimizer = Optimizer::new();
111111
// analyze and optimize the logical plan
112112
let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?;
113-
optimizer.optimize(&plan, &config, |_, _| {})
113+
optimizer.optimize(plan, &config, |_, _| {})
114114
}
115115

116116
#[derive(Default)]

datafusion/expr/src/logical_plan/ddl.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ impl DdlStatement {
112112
}
113113
}
114114

115+
/// Return a mutable reference to the input `LogicalPlan`, if any
116+
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
117+
match self {
118+
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
119+
Some(input)
120+
}
121+
DdlStatement::CreateExternalTable(_) => None,
122+
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
123+
DdlStatement::CreateCatalogSchema(_) => None,
124+
DdlStatement::CreateCatalog(_) => None,
125+
DdlStatement::DropTable(_) => None,
126+
DdlStatement::DropView(_) => None,
127+
DdlStatement::DropCatalogSchema(_) => None,
128+
DdlStatement::CreateFunction(_) => None,
129+
DdlStatement::DropFunction(_) => None,
130+
}
131+
}
132+
115133
/// Return a `format`able structure with the a human readable
116134
/// description of this LogicalPlan node per node, not including
117135
/// children.

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod ddl;
2020
pub mod display;
2121
pub mod dml;
2222
mod extension;
23+
mod mutate;
2324
mod plan;
2425
mod statement;
2526

0 commit comments

Comments
 (0)