Skip to content

Commit 106e7a7

Browse files
committed
refactor(linter/expect-expect): use visitor pattern to detect expect calls (#12906)
fise
1 parent 84794eb commit 106e7a7

File tree

1 file changed

+117
-110
lines changed

1 file changed

+117
-110
lines changed

crates/oxc_linter/src/rules/jest/expect_expect.rs

Lines changed: 117 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
use cow_utils::CowUtils;
22
use lazy_regex::Regex;
3+
use rustc_hash::FxHashSet;
4+
35
use oxc_ast::{
46
AstKind,
5-
ast::{CallExpression, Expression, Statement},
7+
ast::{CallExpression, Expression, FormalParameter, Function, Statement},
68
};
9+
use oxc_ast_visit::{Visit, walk};
710
use oxc_diagnostics::OxcDiagnostic;
811
use oxc_macros::declare_oxc_lint;
912
use oxc_span::{CompactStr, GetSpan, Span};
10-
use rustc_hash::FxHashSet;
13+
use oxc_syntax::scope::ScopeFlags;
1114

1215
use crate::{
1316
ast_util::get_declaration_of_variable,
@@ -162,141 +165,145 @@ fn run<'a>(
162165
}
163166
}
164167

165-
// Record visited nodes to avoid infinite loop.
166-
let mut visited: FxHashSet<Span> = FxHashSet::default();
167-
168-
let has_assert_function = if ctx.frameworks().is_vitest() {
169-
check_arguments(call_expr, &rule.assert_function_names_vitest, &mut visited, ctx)
168+
let assert_function_names = if ctx.frameworks().is_vitest() {
169+
&rule.assert_function_names_vitest
170170
} else {
171-
check_arguments(call_expr, &rule.assert_function_names_jest, &mut visited, ctx)
171+
&rule.assert_function_names_jest
172172
};
173173

174-
if !has_assert_function {
174+
let mut visitor = AssertionVisitor::new(ctx, assert_function_names);
175+
176+
// Visit each argument of the test call
177+
for argument in &call_expr.arguments {
178+
if let Some(expr) = argument.as_expression() {
179+
visitor.check_expression(expr);
180+
if visitor.found_assertion {
181+
return;
182+
}
183+
}
184+
}
185+
186+
if !visitor.found_assertion {
175187
ctx.diagnostic(expect_expect_diagnostic(call_expr.callee.span()));
176188
}
177189
}
178190
}
179191
}
180192

181-
fn check_arguments<'a>(
182-
call_expr: &'a CallExpression<'a>,
183-
assert_function_names: &[CompactStr],
184-
visited: &mut FxHashSet<Span>,
185-
ctx: &LintContext<'a>,
186-
) -> bool {
187-
for argument in &call_expr.arguments {
188-
if let Some(expr) = argument.as_expression() {
189-
if check_assert_function_used(expr, assert_function_names, visited, ctx) {
190-
return true;
191-
}
192-
}
193-
}
194-
false
193+
struct AssertionVisitor<'a, 'b> {
194+
ctx: &'b LintContext<'a>,
195+
assert_function_names: &'b [CompactStr],
196+
visited: FxHashSet<Span>,
197+
found_assertion: bool,
195198
}
196199

197-
fn check_assert_function_used<'a>(
198-
expr: &'a Expression<'a>,
199-
assert_function_names: &[CompactStr],
200-
visited: &mut FxHashSet<Span>,
201-
ctx: &LintContext<'a>,
202-
) -> bool {
203-
// If we have visited this node before and didn't find any assert function, we can return
204-
// `false` to avoid infinite loop.
205-
//
206-
// ```javascript
207-
// test("should fail", () => {
208-
// function foo() {
209-
// if (condition) {
210-
// foo()
211-
// }
212-
// }
213-
// foo()
214-
// })
215-
// ```
216-
if !visited.insert(expr.span()) {
217-
return false;
200+
impl<'a, 'b> AssertionVisitor<'a, 'b> {
201+
fn new(ctx: &'b LintContext<'a>, assert_function_names: &'b [CompactStr]) -> Self {
202+
Self { ctx, assert_function_names, visited: FxHashSet::default(), found_assertion: false }
218203
}
219204

220-
match expr {
221-
Expression::FunctionExpression(fn_expr) => {
222-
let body = &fn_expr.body;
223-
if let Some(body) = body {
224-
return check_statements(&body.statements, assert_function_names, visited, ctx);
225-
}
226-
}
227-
Expression::ArrowFunctionExpression(arrow_expr) => {
228-
let body = &arrow_expr.body;
229-
return check_statements(&body.statements, assert_function_names, visited, ctx);
205+
fn check_expression(&mut self, expr: &Expression<'a>) {
206+
// Avoid infinite loops by tracking visited expressions
207+
if !self.visited.insert(expr.span()) {
208+
return;
230209
}
231-
Expression::CallExpression(call_expr) => {
232-
let name = get_node_name(&call_expr.callee);
233-
if matches_assert_function_name(&name, assert_function_names) {
234-
return true;
235-
}
236-
237-
// If CallExpression is not an assert function, we need to check its arguments, it may trigger
238-
// another assert function.
239-
// ```javascript
240-
// it('should pass', () => somePromise().then(() => expect(true).toBeDefined()))
241-
// ```
242-
let has_assert_function =
243-
check_arguments(call_expr, assert_function_names, visited, ctx);
244210

245-
return has_assert_function;
246-
}
247-
Expression::Identifier(ident) => {
248-
let Some(node) = get_declaration_of_variable(ident, ctx) else {
249-
return false;
250-
};
251-
let AstKind::Function(function) = node.kind() else {
252-
return false;
253-
};
254-
let Some(body) = &function.body else {
255-
return false;
256-
};
257-
return check_statements(&body.statements, assert_function_names, visited, ctx);
258-
}
259-
Expression::AwaitExpression(expr) => {
260-
return check_assert_function_used(&expr.argument, assert_function_names, visited, ctx);
261-
}
262-
Expression::ArrayExpression(array_expr) => {
263-
for element in &array_expr.elements {
264-
if let Some(element_expr) = element.as_expression() {
265-
if check_assert_function_used(element_expr, assert_function_names, visited, ctx)
266-
{
267-
return true;
211+
match expr {
212+
Expression::FunctionExpression(fn_expr) => {
213+
if let Some(body) = &fn_expr.body {
214+
self.visit_function_body(body);
215+
}
216+
}
217+
Expression::ArrowFunctionExpression(arrow_expr) => {
218+
self.visit_function_body(&arrow_expr.body);
219+
}
220+
Expression::CallExpression(call_expr) => {
221+
self.visit_call_expression(call_expr);
222+
}
223+
Expression::Identifier(ident) => {
224+
self.check_identifier(ident);
225+
}
226+
Expression::AwaitExpression(expr) => {
227+
self.check_expression(&expr.argument);
228+
}
229+
Expression::ArrayExpression(array_expr) => {
230+
for element in &array_expr.elements {
231+
if let Some(element_expr) = element.as_expression() {
232+
self.check_expression(element_expr);
233+
if self.found_assertion {
234+
return;
235+
}
268236
}
269237
}
270238
}
239+
_ => {}
271240
}
272-
_ => {}
273241
}
274242

275-
false
243+
fn check_identifier(&mut self, ident: &oxc_ast::ast::IdentifierReference<'a>) {
244+
let Some(node) = get_declaration_of_variable(ident, self.ctx) else {
245+
return;
246+
};
247+
let AstKind::Function(function) = node.kind() else {
248+
return;
249+
};
250+
if let Some(body) = &function.body {
251+
self.visit_function_body(body);
252+
}
253+
}
276254
}
277255

278-
fn check_statements<'a>(
279-
statements: &'a oxc_allocator::Vec<Statement<'a>>,
280-
assert_function_names: &[CompactStr],
281-
visited: &mut FxHashSet<Span>,
282-
ctx: &LintContext<'a>,
283-
) -> bool {
284-
statements.iter().any(|statement| match statement {
285-
Statement::ExpressionStatement(expr_stmt) => {
286-
check_assert_function_used(&expr_stmt.expression, assert_function_names, visited, ctx)
256+
impl<'a> Visit<'a> for AssertionVisitor<'a, '_> {
257+
fn visit_call_expression(&mut self, call_expr: &CallExpression<'a>) {
258+
let name = get_node_name(&call_expr.callee);
259+
if matches_assert_function_name(&name, self.assert_function_names) {
260+
self.found_assertion = true;
261+
return;
287262
}
288-
Statement::BlockStatement(block_stmt) => {
289-
check_statements(&block_stmt.body, assert_function_names, visited, ctx)
263+
264+
for argument in &call_expr.arguments {
265+
if let Some(expr) = argument.as_expression() {
266+
self.check_expression(expr);
267+
if self.found_assertion {
268+
return;
269+
}
270+
}
290271
}
291-
Statement::IfStatement(if_stmt) => {
292-
if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent {
293-
check_statements(&block_stmt.body, assert_function_names, visited, ctx)
294-
} else {
295-
false
272+
273+
walk::walk_call_expression(self, call_expr);
274+
}
275+
276+
fn visit_expression_statement(&mut self, stmt: &oxc_ast::ast::ExpressionStatement<'a>) {
277+
self.check_expression(&stmt.expression);
278+
if !self.found_assertion {
279+
walk::walk_expression_statement(self, stmt);
280+
}
281+
}
282+
283+
fn visit_block_statement(&mut self, block: &oxc_ast::ast::BlockStatement<'a>) {
284+
for stmt in &block.body {
285+
self.visit_statement(stmt);
286+
if self.found_assertion {
287+
return;
296288
}
297289
}
298-
_ => false,
299-
})
290+
}
291+
292+
fn visit_if_statement(&mut self, if_stmt: &oxc_ast::ast::IfStatement<'a>) {
293+
if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent {
294+
self.visit_block_statement(block_stmt);
295+
}
296+
if self.found_assertion {
297+
return;
298+
}
299+
if let Some(alternate) = &if_stmt.alternate {
300+
self.visit_statement(alternate);
301+
}
302+
}
303+
304+
fn visit_function(&mut self, _func: &Function<'a>, _flags: ScopeFlags) {}
305+
306+
fn visit_formal_parameter(&mut self, _param: &FormalParameter<'a>) {}
300307
}
301308

302309
/// Checks if node names returned by getNodeName matches any of the given star patterns

0 commit comments

Comments
 (0)