Skip to content

Commit df225e9

Browse files
committed
feat(formatter): add AstNode::ancestor and AstNode::grand_parent methods (#14700)
Added a `AstNode::ancestor` to simplify the traverse upward pattern, and added a `AstNode::grand_parent` to simplify `self.parent.parent()` usages
1 parent bae5f11 commit df225e9

File tree

16 files changed

+138
-69
lines changed

16 files changed

+138
-69
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//! Implementations of methods for [`AstNodes`].
2+
3+
use crate::ast_nodes::AstNodes;
4+
5+
impl<'a> AstNodes<'a> {
6+
/// Returns an iterator over all ancestor nodes in the AST, starting from self.
7+
///
8+
/// The iteration includes the current node and proceeds upward through the tree,
9+
/// terminating after yielding the root `Program` node.
10+
///
11+
/// # Example hierarchy
12+
/// ```text
13+
/// Program
14+
/// └─ BlockStatement
15+
/// └─ ExpressionStatement <- self
16+
/// ```
17+
/// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program]
18+
pub fn ancestors(&self) -> impl Iterator<Item = &AstNodes<'a>> {
19+
// Start with the current node and walk up the tree, including Program
20+
std::iter::successors(Some(self), |node| {
21+
// Continue iteration until we've yielded Program (root node)
22+
// After Program, parent() would still return Program, so stop there
23+
if matches!(node, AstNodes::Program(_)) { None } else { Some(node.parent()) }
24+
})
25+
}
26+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod ast_nodes;

crates/oxc_formatter/src/ast_nodes/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod generated;
2+
pub mod impls;
23
mod iterator;
34
mod node;
45

crates/oxc_formatter/src/ast_nodes/node.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,48 @@ impl<T: GetSpan> GetSpan for &AstNode<'_, T> {
6363
}
6464
}
6565

66+
impl<T> AstNode<'_, T> {
67+
/// Returns an iterator over all ancestor nodes in the AST, starting from self.
68+
///
69+
/// The iteration includes the current node and proceeds upward through the tree,
70+
/// terminating after yielding the root `Program` node.
71+
///
72+
/// This is a convenience method that delegates to `self.parent.ancestors()`.
73+
///
74+
/// # Example
75+
/// ```text
76+
/// Program
77+
/// └─ BlockStatement
78+
/// └─ ExpressionStatement <- self
79+
/// ```
80+
/// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program]
81+
///
82+
/// # Usage
83+
/// ```ignore
84+
/// // Find the first ancestor that matches a condition
85+
/// let parent = self.ancestors()
86+
/// .find(|p| matches!(p, AstNodes::ForStatement(_)))
87+
/// .unwrap();
88+
///
89+
/// // Get the nth ancestor
90+
/// let great_grandparent = self.ancestors().nth(3);
91+
///
92+
/// // Check if any ancestor is a specific type
93+
/// let in_arrow_fn = self.ancestors()
94+
/// .any(|p| matches!(p, AstNodes::ArrowFunctionExpression(_)));
95+
/// ```
96+
pub fn ancestors(&self) -> impl Iterator<Item = &AstNodes<'_>> {
97+
self.parent.ancestors()
98+
}
99+
100+
/// Returns the grandparent node (parent's parent).
101+
///
102+
/// This is a convenience method equivalent to `self.parent.parent()`.
103+
pub fn grand_parent(&self) -> &AstNodes<'_> {
104+
self.parent.parent()
105+
}
106+
}
107+
66108
impl<'a> AstNode<'a, Program<'a>> {
67109
pub fn new(inner: &'a Program<'a>, parent: &'a AstNodes<'a>, allocator: &'a Allocator) -> Self {
68110
AstNode { inner, parent, allocator, following_span: None }

crates/oxc_formatter/src/parentheses/expression.rs

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> {
8484
matches!(self.parent, AstNodes::ForOfStatement(stmt) if !stmt.r#await && stmt.left.span().contains_inclusive(self.span))
8585
}
8686
"let" => {
87-
let mut parent = self.parent;
88-
loop {
87+
// Walk up ancestors to find the relevant context for `let` keyword
88+
for parent in self.ancestors() {
8989
match parent {
90-
AstNodes::Program(_) | AstNodes::ExpressionStatement(_) => return false,
90+
AstNodes::ExpressionStatement(_) => return false,
9191
AstNodes::ForOfStatement(stmt) => {
9292
return stmt.left.span().contains_inclusive(self.span);
9393
}
9494
AstNodes::TSSatisfiesExpression(expr) => {
9595
return expr.expression.span() == self.span();
9696
}
97-
_ => parent = parent.parent(),
97+
_ => {}
9898
}
9999
}
100-
unreachable!()
100+
false
101101
}
102102
name => {
103103
// <https://github.com/prettier/prettier/blob/7584432401a47a26943dd7a9ca9a8e032ead7285/src/language-js/needs-parens.js#L123-L133>
@@ -131,7 +131,7 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> {
131131
matches!(
132132
parent, AstNodes::ExpressionStatement(stmt) if
133133
!matches!(
134-
stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow)
134+
stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow)
135135
if arrow.expression()
136136
)
137137
)
@@ -392,8 +392,9 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, BinaryExpression<'a>> {
392392

393393
/// Add parentheses if the `in` is inside of a `for` initializer (see tests).
394394
fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
395-
let mut parent = expr.parent;
396-
loop {
395+
let mut ancestors = expr.ancestors();
396+
397+
while let Some(parent) = ancestors.next() {
397398
match parent {
398399
AstNodes::ExpressionStatement(stmt) => {
399400
let grand_parent = parent.parent();
@@ -404,7 +405,13 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
404405
grand_grand_parent,
405406
AstNodes::ArrowFunctionExpression(arrow) if arrow.expression()
406407
) {
407-
parent = grand_grand_parent;
408+
// Skip ahead to grand_grand_parent by consuming ancestors
409+
// until we reach it
410+
for ancestor in ancestors.by_ref() {
411+
if core::ptr::eq(ancestor, grand_grand_parent) {
412+
break;
413+
}
414+
}
408415
continue;
409416
}
410417
}
@@ -423,11 +430,11 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
423430
AstNodes::Program(_) => {
424431
return false;
425432
}
426-
_ => {
427-
parent = parent.parent();
428-
}
433+
_ => {}
429434
}
430435
}
436+
437+
false
431438
}
432439

433440
impl<'a> NeedsParentheses<'a> for AstNode<'a, PrivateInExpression<'a>> {
@@ -546,25 +553,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, AssignmentExpression<'a>> {
546553
// - `a = 1, b = 2` in for loops don't need parens
547554
// - `(a = 1, b = 2)` elsewhere usually need parens
548555
AstNodes::SequenceExpression(sequence) => {
549-
let mut current_parent = self.parent;
550-
loop {
551-
match current_parent {
552-
AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_) => {
553-
current_parent = current_parent.parent();
554-
}
555-
AstNodes::ForStatement(for_stmt) => {
556-
let is_initializer = for_stmt
557-
.init
558-
.as_ref()
559-
.is_some_and(|init| init.span().contains_inclusive(self.span()));
560-
let is_update = for_stmt.update.as_ref().is_some_and(|update| {
561-
update.span().contains_inclusive(self.span())
562-
});
563-
return !(is_initializer || is_update);
564-
}
565-
_ => break,
556+
// Skip through SequenceExpression and ParenthesizedExpression ancestors
557+
if let Some(ancestor) = self.ancestors().find(|p| {
558+
!matches!(p, AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_))
559+
}) && let AstNodes::ForStatement(for_stmt) = ancestor {
560+
let is_initializer = for_stmt
561+
.init
562+
.as_ref()
563+
.is_some_and(|init| init.span().contains_inclusive(self.span()));
564+
let is_update = for_stmt.update.as_ref().is_some_and(|update| {
565+
update.span().contains_inclusive(self.span())
566+
});
567+
return !(is_initializer || is_update);
566568
}
567-
}
569+
568570
true
569571
}
570572
// `interface { [a = 1]; }` and `class { [a = 1]; }` not need parens
@@ -620,8 +622,8 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, SequenceExpression<'a>> {
620622
}
621623
}
622624

623-
impl<'a> NeedsParentheses<'a> for AstNode<'a, AwaitExpression<'a>> {
624-
fn needs_parentheses(&self, f: &Formatter<'_, 'a>) -> bool {
625+
impl NeedsParentheses<'_> for AstNode<'_, AwaitExpression<'_>> {
626+
fn needs_parentheses(&self, f: &Formatter<'_, '_>) -> bool {
625627
if f.comments().is_type_cast_node(self) {
626628
return false;
627629
}
@@ -977,14 +979,15 @@ pub enum FirstInStatementMode {
977979
/// the left most node or reached a statement.
978980
fn is_first_in_statement(
979981
mut current_span: Span,
980-
mut parent: &AstNodes<'_>,
982+
parent: &AstNodes<'_>,
981983
mode: FirstInStatementMode,
982984
) -> bool {
983-
let mut is_not_first_iteration = false;
984-
loop {
985-
match parent {
985+
for (index, ancestor) in parent.ancestors().enumerate() {
986+
let is_not_first_iteration = index > 0;
987+
988+
match ancestor {
986989
AstNodes::ExpressionStatement(stmt) => {
987-
if matches!(stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
990+
if matches!(stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
988991
{
989992
if mode == FirstInStatementMode::ExpressionStatementOrArrow {
990993
if is_not_first_iteration
@@ -1051,9 +1054,7 @@ fn is_first_in_statement(
10511054
}
10521055
_ => break,
10531056
}
1054-
current_span = parent.span();
1055-
parent = parent.parent();
1056-
is_not_first_iteration = true;
1057+
current_span = ancestor.span();
10571058
}
10581059

10591060
false

crates/oxc_formatter/src/utils/call_expression.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ pub fn is_test_call_expression(call: &AstNode<CallExpression<'_>>) -> bool {
3737
match (args.next(), args.next(), args.next()) {
3838
(Some(argument), None, None) if arguments.len() == 1 => {
3939
if is_angular_test_wrapper(call) && {
40-
if let AstNodes::CallExpression(call) = call.parent.parent() {
40+
if let AstNodes::CallExpression(call) = call.grand_parent() {
4141
is_test_call_expression(call)
4242
} else {
4343
false

crates/oxc_formatter/src/utils/member_chain/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ impl<'a, 'b> MemberChain<'a, 'b> {
9696
is_factory(&identifier.name) ||
9797
// If an identifier has a name that is shorter than the tab with, then we join it with the "head"
9898
(matches!(parent, AstNodes::ExpressionStatement(stmt) if {
99-
if let AstNodes::ArrowFunctionExpression(arrow) = stmt.parent.parent() {
99+
if let AstNodes::ArrowFunctionExpression(arrow) = stmt.grand_parent() {
100100
!arrow.expression
101101
} else {
102102
true

crates/oxc_formatter/src/write/call_arguments.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ impl<'a> Format<'a> for AstNode<'a, ArenaVec<'a, Argument<'a>>> {
115115
});
116116

117117
if has_empty_line
118-
|| (!matches!(self.parent.parent(), AstNodes::Decorator(_))
118+
|| (!matches!(self.grand_parent(), AstNodes::Decorator(_))
119119
&& is_function_composition_args(self))
120120
{
121121
return format_all_args_broken_out(self, true, f);

crates/oxc_formatter/src/write/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ impl<'a> Format<'a> for FormatClass<'a, '_> {
425425
}
426426
});
427427

428-
if matches!(extends.parent.parent(), AstNodes::AssignmentExpression(_)) {
428+
if matches!(extends.grand_parent(), AstNodes::AssignmentExpression(_)) {
429429
if has_trailing_comments {
430430
write!(f, [text("("), &content, text(")")])
431431
} else {

crates/oxc_formatter/src/write/jsx/element.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,18 +181,18 @@ impl<'a> Format<'a> for AnyJsxTagWithChildren<'a, '_> {
181181
/// </div>;
182182
/// ```
183183
pub fn should_expand(mut parent: &AstNodes<'_>) -> bool {
184-
if matches!(parent, AstNodes::ExpressionStatement(_)) {
184+
if let AstNodes::ExpressionStatement(stmt) = parent {
185185
// If the parent is a JSXExpressionContainer, we need to check its parent
186186
// to determine if it should expand.
187-
parent = parent.parent().parent();
187+
parent = stmt.grand_parent();
188188
}
189189
let maybe_jsx_expression_child = match parent {
190190
AstNodes::ArrowFunctionExpression(arrow) if arrow.expression => match arrow.parent {
191191
// Argument
192192
AstNodes::Argument(argument)
193193
if matches!(argument.parent, AstNodes::CallExpression(_)) =>
194194
{
195-
argument.parent.parent()
195+
argument.grand_parent()
196196
}
197197
// Callee
198198
AstNodes::CallExpression(call) => call.parent,

0 commit comments

Comments
 (0)