Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 117 additions & 110 deletions crates/oxc_linter/src/rules/jest/expect_expect.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use cow_utils::CowUtils;
use lazy_regex::Regex;
use rustc_hash::FxHashSet;

use oxc_ast::{
AstKind,
ast::{CallExpression, Expression, Statement},
ast::{CallExpression, Expression, FormalParameter, Function, Statement},
};
use oxc_ast_visit::{Visit, walk};
use oxc_diagnostics::OxcDiagnostic;
use oxc_macros::declare_oxc_lint;
use oxc_span::{CompactStr, GetSpan, Span};
use rustc_hash::FxHashSet;
use oxc_syntax::scope::ScopeFlags;

use crate::{
ast_util::get_declaration_of_variable,
Expand Down Expand Up @@ -162,141 +165,145 @@ fn run<'a>(
}
}

// Record visited nodes to avoid infinite loop.
let mut visited: FxHashSet<Span> = FxHashSet::default();

let has_assert_function = if ctx.frameworks().is_vitest() {
check_arguments(call_expr, &rule.assert_function_names_vitest, &mut visited, ctx)
let assert_function_names = if ctx.frameworks().is_vitest() {
&rule.assert_function_names_vitest
} else {
check_arguments(call_expr, &rule.assert_function_names_jest, &mut visited, ctx)
&rule.assert_function_names_jest
};

if !has_assert_function {
let mut visitor = AssertionVisitor::new(ctx, assert_function_names);

// Visit each argument of the test call
for argument in &call_expr.arguments {
if let Some(expr) = argument.as_expression() {
visitor.check_expression(expr);
if visitor.found_assertion {
return;
}
}
}

if !visitor.found_assertion {
ctx.diagnostic(expect_expect_diagnostic(call_expr.callee.span()));
}
}
}
}

fn check_arguments<'a>(
call_expr: &'a CallExpression<'a>,
assert_function_names: &[CompactStr],
visited: &mut FxHashSet<Span>,
ctx: &LintContext<'a>,
) -> bool {
for argument in &call_expr.arguments {
if let Some(expr) = argument.as_expression() {
if check_assert_function_used(expr, assert_function_names, visited, ctx) {
return true;
}
}
}
false
struct AssertionVisitor<'a, 'b> {
ctx: &'b LintContext<'a>,
assert_function_names: &'b [CompactStr],
visited: FxHashSet<Span>,
found_assertion: bool,
}

fn check_assert_function_used<'a>(
expr: &'a Expression<'a>,
assert_function_names: &[CompactStr],
visited: &mut FxHashSet<Span>,
ctx: &LintContext<'a>,
) -> bool {
// If we have visited this node before and didn't find any assert function, we can return
// `false` to avoid infinite loop.
//
// ```javascript
// test("should fail", () => {
// function foo() {
// if (condition) {
// foo()
// }
// }
// foo()
// })
// ```
if !visited.insert(expr.span()) {
return false;
impl<'a, 'b> AssertionVisitor<'a, 'b> {
fn new(ctx: &'b LintContext<'a>, assert_function_names: &'b [CompactStr]) -> Self {
Self { ctx, assert_function_names, visited: FxHashSet::default(), found_assertion: false }
}

match expr {
Expression::FunctionExpression(fn_expr) => {
let body = &fn_expr.body;
if let Some(body) = body {
return check_statements(&body.statements, assert_function_names, visited, ctx);
}
}
Expression::ArrowFunctionExpression(arrow_expr) => {
let body = &arrow_expr.body;
return check_statements(&body.statements, assert_function_names, visited, ctx);
fn check_expression(&mut self, expr: &Expression<'a>) {
// Avoid infinite loops by tracking visited expressions
if !self.visited.insert(expr.span()) {
return;
}
Expression::CallExpression(call_expr) => {
let name = get_node_name(&call_expr.callee);
if matches_assert_function_name(&name, assert_function_names) {
return true;
}

// If CallExpression is not an assert function, we need to check its arguments, it may trigger
// another assert function.
// ```javascript
// it('should pass', () => somePromise().then(() => expect(true).toBeDefined()))
// ```
let has_assert_function =
check_arguments(call_expr, assert_function_names, visited, ctx);

return has_assert_function;
}
Expression::Identifier(ident) => {
let Some(node) = get_declaration_of_variable(ident, ctx) else {
return false;
};
let AstKind::Function(function) = node.kind() else {
return false;
};
let Some(body) = &function.body else {
return false;
};
return check_statements(&body.statements, assert_function_names, visited, ctx);
}
Expression::AwaitExpression(expr) => {
return check_assert_function_used(&expr.argument, assert_function_names, visited, ctx);
}
Expression::ArrayExpression(array_expr) => {
for element in &array_expr.elements {
if let Some(element_expr) = element.as_expression() {
if check_assert_function_used(element_expr, assert_function_names, visited, ctx)
{
return true;
match expr {
Expression::FunctionExpression(fn_expr) => {
if let Some(body) = &fn_expr.body {
self.visit_function_body(body);
}
}
Expression::ArrowFunctionExpression(arrow_expr) => {
self.visit_function_body(&arrow_expr.body);
}
Expression::CallExpression(call_expr) => {
self.visit_call_expression(call_expr);
}
Expression::Identifier(ident) => {
self.check_identifier(ident);
}
Expression::AwaitExpression(expr) => {
self.check_expression(&expr.argument);
}
Expression::ArrayExpression(array_expr) => {
for element in &array_expr.elements {
if let Some(element_expr) = element.as_expression() {
self.check_expression(element_expr);
if self.found_assertion {
return;
}
}
}
}
_ => {}
}
_ => {}
}

false
fn check_identifier(&mut self, ident: &oxc_ast::ast::IdentifierReference<'a>) {
let Some(node) = get_declaration_of_variable(ident, self.ctx) else {
return;
};
let AstKind::Function(function) = node.kind() else {
return;
};
if let Some(body) = &function.body {
self.visit_function_body(body);
}
}
}

fn check_statements<'a>(
statements: &'a oxc_allocator::Vec<Statement<'a>>,
assert_function_names: &[CompactStr],
visited: &mut FxHashSet<Span>,
ctx: &LintContext<'a>,
) -> bool {
statements.iter().any(|statement| match statement {
Statement::ExpressionStatement(expr_stmt) => {
check_assert_function_used(&expr_stmt.expression, assert_function_names, visited, ctx)
impl<'a> Visit<'a> for AssertionVisitor<'a, '_> {
fn visit_call_expression(&mut self, call_expr: &CallExpression<'a>) {
let name = get_node_name(&call_expr.callee);
if matches_assert_function_name(&name, self.assert_function_names) {
self.found_assertion = true;
return;
}
Statement::BlockStatement(block_stmt) => {
check_statements(&block_stmt.body, assert_function_names, visited, ctx)

for argument in &call_expr.arguments {
if let Some(expr) = argument.as_expression() {
self.check_expression(expr);
if self.found_assertion {
return;
}
}
}
Statement::IfStatement(if_stmt) => {
if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent {
check_statements(&block_stmt.body, assert_function_names, visited, ctx)
} else {
false

walk::walk_call_expression(self, call_expr);
}

fn visit_expression_statement(&mut self, stmt: &oxc_ast::ast::ExpressionStatement<'a>) {
self.check_expression(&stmt.expression);
if !self.found_assertion {
walk::walk_expression_statement(self, stmt);
}
}

fn visit_block_statement(&mut self, block: &oxc_ast::ast::BlockStatement<'a>) {
for stmt in &block.body {
self.visit_statement(stmt);
if self.found_assertion {
return;
}
}
_ => false,
})
}

fn visit_if_statement(&mut self, if_stmt: &oxc_ast::ast::IfStatement<'a>) {
if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent {
self.visit_block_statement(block_stmt);
}
if self.found_assertion {
return;
}
if let Some(alternate) = &if_stmt.alternate {
self.visit_statement(alternate);
}
}

fn visit_function(&mut self, _func: &Function<'a>, _flags: ScopeFlags) {}

fn visit_formal_parameter(&mut self, _param: &FormalParameter<'a>) {}
}

/// Checks if node names returned by getNodeName matches any of the given star patterns
Expand Down
Loading