|
1 | 1 | use cow_utils::CowUtils; |
2 | 2 | use lazy_regex::Regex; |
| 3 | +use rustc_hash::FxHashSet; |
| 4 | + |
3 | 5 | use oxc_ast::{ |
4 | 6 | AstKind, |
5 | | - ast::{CallExpression, Expression, Statement}, |
| 7 | + ast::{CallExpression, Expression, FormalParameter, Function, Statement}, |
6 | 8 | }; |
| 9 | +use oxc_ast_visit::{Visit, walk}; |
7 | 10 | use oxc_diagnostics::OxcDiagnostic; |
8 | 11 | use oxc_macros::declare_oxc_lint; |
9 | 12 | use oxc_span::{CompactStr, GetSpan, Span}; |
10 | | -use rustc_hash::FxHashSet; |
| 13 | +use oxc_syntax::scope::ScopeFlags; |
11 | 14 |
|
12 | 15 | use crate::{ |
13 | 16 | ast_util::get_declaration_of_variable, |
@@ -162,141 +165,145 @@ fn run<'a>( |
162 | 165 | } |
163 | 166 | } |
164 | 167 |
|
165 | | - // Record visited nodes to avoid infinite loop. |
166 | | - let mut visited: FxHashSet<Span> = FxHashSet::default(); |
167 | | - |
168 | | - let has_assert_function = if ctx.frameworks().is_vitest() { |
169 | | - check_arguments(call_expr, &rule.assert_function_names_vitest, &mut visited, ctx) |
| 168 | + let assert_function_names = if ctx.frameworks().is_vitest() { |
| 169 | + &rule.assert_function_names_vitest |
170 | 170 | } else { |
171 | | - check_arguments(call_expr, &rule.assert_function_names_jest, &mut visited, ctx) |
| 171 | + &rule.assert_function_names_jest |
172 | 172 | }; |
173 | 173 |
|
174 | | - if !has_assert_function { |
| 174 | + let mut visitor = AssertionVisitor::new(ctx, assert_function_names); |
| 175 | + |
| 176 | + // Visit each argument of the test call |
| 177 | + for argument in &call_expr.arguments { |
| 178 | + if let Some(expr) = argument.as_expression() { |
| 179 | + visitor.check_expression(expr); |
| 180 | + if visitor.found_assertion { |
| 181 | + return; |
| 182 | + } |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + if !visitor.found_assertion { |
175 | 187 | ctx.diagnostic(expect_expect_diagnostic(call_expr.callee.span())); |
176 | 188 | } |
177 | 189 | } |
178 | 190 | } |
179 | 191 | } |
180 | 192 |
|
181 | | -fn check_arguments<'a>( |
182 | | - call_expr: &'a CallExpression<'a>, |
183 | | - assert_function_names: &[CompactStr], |
184 | | - visited: &mut FxHashSet<Span>, |
185 | | - ctx: &LintContext<'a>, |
186 | | -) -> bool { |
187 | | - for argument in &call_expr.arguments { |
188 | | - if let Some(expr) = argument.as_expression() { |
189 | | - if check_assert_function_used(expr, assert_function_names, visited, ctx) { |
190 | | - return true; |
191 | | - } |
192 | | - } |
193 | | - } |
194 | | - false |
| 193 | +struct AssertionVisitor<'a, 'b> { |
| 194 | + ctx: &'b LintContext<'a>, |
| 195 | + assert_function_names: &'b [CompactStr], |
| 196 | + visited: FxHashSet<Span>, |
| 197 | + found_assertion: bool, |
195 | 198 | } |
196 | 199 |
|
197 | | -fn check_assert_function_used<'a>( |
198 | | - expr: &'a Expression<'a>, |
199 | | - assert_function_names: &[CompactStr], |
200 | | - visited: &mut FxHashSet<Span>, |
201 | | - ctx: &LintContext<'a>, |
202 | | -) -> bool { |
203 | | - // If we have visited this node before and didn't find any assert function, we can return |
204 | | - // `false` to avoid infinite loop. |
205 | | - // |
206 | | - // ```javascript |
207 | | - // test("should fail", () => { |
208 | | - // function foo() { |
209 | | - // if (condition) { |
210 | | - // foo() |
211 | | - // } |
212 | | - // } |
213 | | - // foo() |
214 | | - // }) |
215 | | - // ``` |
216 | | - if !visited.insert(expr.span()) { |
217 | | - return false; |
| 200 | +impl<'a, 'b> AssertionVisitor<'a, 'b> { |
| 201 | + fn new(ctx: &'b LintContext<'a>, assert_function_names: &'b [CompactStr]) -> Self { |
| 202 | + Self { ctx, assert_function_names, visited: FxHashSet::default(), found_assertion: false } |
218 | 203 | } |
219 | 204 |
|
220 | | - match expr { |
221 | | - Expression::FunctionExpression(fn_expr) => { |
222 | | - let body = &fn_expr.body; |
223 | | - if let Some(body) = body { |
224 | | - return check_statements(&body.statements, assert_function_names, visited, ctx); |
225 | | - } |
226 | | - } |
227 | | - Expression::ArrowFunctionExpression(arrow_expr) => { |
228 | | - let body = &arrow_expr.body; |
229 | | - return check_statements(&body.statements, assert_function_names, visited, ctx); |
| 205 | + fn check_expression(&mut self, expr: &Expression<'a>) { |
| 206 | + // Avoid infinite loops by tracking visited expressions |
| 207 | + if !self.visited.insert(expr.span()) { |
| 208 | + return; |
230 | 209 | } |
231 | | - Expression::CallExpression(call_expr) => { |
232 | | - let name = get_node_name(&call_expr.callee); |
233 | | - if matches_assert_function_name(&name, assert_function_names) { |
234 | | - return true; |
235 | | - } |
236 | | - |
237 | | - // If CallExpression is not an assert function, we need to check its arguments, it may trigger |
238 | | - // another assert function. |
239 | | - // ```javascript |
240 | | - // it('should pass', () => somePromise().then(() => expect(true).toBeDefined())) |
241 | | - // ``` |
242 | | - let has_assert_function = |
243 | | - check_arguments(call_expr, assert_function_names, visited, ctx); |
244 | 210 |
|
245 | | - return has_assert_function; |
246 | | - } |
247 | | - Expression::Identifier(ident) => { |
248 | | - let Some(node) = get_declaration_of_variable(ident, ctx) else { |
249 | | - return false; |
250 | | - }; |
251 | | - let AstKind::Function(function) = node.kind() else { |
252 | | - return false; |
253 | | - }; |
254 | | - let Some(body) = &function.body else { |
255 | | - return false; |
256 | | - }; |
257 | | - return check_statements(&body.statements, assert_function_names, visited, ctx); |
258 | | - } |
259 | | - Expression::AwaitExpression(expr) => { |
260 | | - return check_assert_function_used(&expr.argument, assert_function_names, visited, ctx); |
261 | | - } |
262 | | - Expression::ArrayExpression(array_expr) => { |
263 | | - for element in &array_expr.elements { |
264 | | - if let Some(element_expr) = element.as_expression() { |
265 | | - if check_assert_function_used(element_expr, assert_function_names, visited, ctx) |
266 | | - { |
267 | | - return true; |
| 211 | + match expr { |
| 212 | + Expression::FunctionExpression(fn_expr) => { |
| 213 | + if let Some(body) = &fn_expr.body { |
| 214 | + self.visit_function_body(body); |
| 215 | + } |
| 216 | + } |
| 217 | + Expression::ArrowFunctionExpression(arrow_expr) => { |
| 218 | + self.visit_function_body(&arrow_expr.body); |
| 219 | + } |
| 220 | + Expression::CallExpression(call_expr) => { |
| 221 | + self.visit_call_expression(call_expr); |
| 222 | + } |
| 223 | + Expression::Identifier(ident) => { |
| 224 | + self.check_identifier(ident); |
| 225 | + } |
| 226 | + Expression::AwaitExpression(expr) => { |
| 227 | + self.check_expression(&expr.argument); |
| 228 | + } |
| 229 | + Expression::ArrayExpression(array_expr) => { |
| 230 | + for element in &array_expr.elements { |
| 231 | + if let Some(element_expr) = element.as_expression() { |
| 232 | + self.check_expression(element_expr); |
| 233 | + if self.found_assertion { |
| 234 | + return; |
| 235 | + } |
268 | 236 | } |
269 | 237 | } |
270 | 238 | } |
| 239 | + _ => {} |
271 | 240 | } |
272 | | - _ => {} |
273 | 241 | } |
274 | 242 |
|
275 | | - false |
| 243 | + fn check_identifier(&mut self, ident: &oxc_ast::ast::IdentifierReference<'a>) { |
| 244 | + let Some(node) = get_declaration_of_variable(ident, self.ctx) else { |
| 245 | + return; |
| 246 | + }; |
| 247 | + let AstKind::Function(function) = node.kind() else { |
| 248 | + return; |
| 249 | + }; |
| 250 | + if let Some(body) = &function.body { |
| 251 | + self.visit_function_body(body); |
| 252 | + } |
| 253 | + } |
276 | 254 | } |
277 | 255 |
|
278 | | -fn check_statements<'a>( |
279 | | - statements: &'a oxc_allocator::Vec<Statement<'a>>, |
280 | | - assert_function_names: &[CompactStr], |
281 | | - visited: &mut FxHashSet<Span>, |
282 | | - ctx: &LintContext<'a>, |
283 | | -) -> bool { |
284 | | - statements.iter().any(|statement| match statement { |
285 | | - Statement::ExpressionStatement(expr_stmt) => { |
286 | | - check_assert_function_used(&expr_stmt.expression, assert_function_names, visited, ctx) |
| 256 | +impl<'a> Visit<'a> for AssertionVisitor<'a, '_> { |
| 257 | + fn visit_call_expression(&mut self, call_expr: &CallExpression<'a>) { |
| 258 | + let name = get_node_name(&call_expr.callee); |
| 259 | + if matches_assert_function_name(&name, self.assert_function_names) { |
| 260 | + self.found_assertion = true; |
| 261 | + return; |
287 | 262 | } |
288 | | - Statement::BlockStatement(block_stmt) => { |
289 | | - check_statements(&block_stmt.body, assert_function_names, visited, ctx) |
| 263 | + |
| 264 | + for argument in &call_expr.arguments { |
| 265 | + if let Some(expr) = argument.as_expression() { |
| 266 | + self.check_expression(expr); |
| 267 | + if self.found_assertion { |
| 268 | + return; |
| 269 | + } |
| 270 | + } |
290 | 271 | } |
291 | | - Statement::IfStatement(if_stmt) => { |
292 | | - if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent { |
293 | | - check_statements(&block_stmt.body, assert_function_names, visited, ctx) |
294 | | - } else { |
295 | | - false |
| 272 | + |
| 273 | + walk::walk_call_expression(self, call_expr); |
| 274 | + } |
| 275 | + |
| 276 | + fn visit_expression_statement(&mut self, stmt: &oxc_ast::ast::ExpressionStatement<'a>) { |
| 277 | + self.check_expression(&stmt.expression); |
| 278 | + if !self.found_assertion { |
| 279 | + walk::walk_expression_statement(self, stmt); |
| 280 | + } |
| 281 | + } |
| 282 | + |
| 283 | + fn visit_block_statement(&mut self, block: &oxc_ast::ast::BlockStatement<'a>) { |
| 284 | + for stmt in &block.body { |
| 285 | + self.visit_statement(stmt); |
| 286 | + if self.found_assertion { |
| 287 | + return; |
296 | 288 | } |
297 | 289 | } |
298 | | - _ => false, |
299 | | - }) |
| 290 | + } |
| 291 | + |
| 292 | + fn visit_if_statement(&mut self, if_stmt: &oxc_ast::ast::IfStatement<'a>) { |
| 293 | + if let Statement::BlockStatement(block_stmt) = &if_stmt.consequent { |
| 294 | + self.visit_block_statement(block_stmt); |
| 295 | + } |
| 296 | + if self.found_assertion { |
| 297 | + return; |
| 298 | + } |
| 299 | + if let Some(alternate) = &if_stmt.alternate { |
| 300 | + self.visit_statement(alternate); |
| 301 | + } |
| 302 | + } |
| 303 | + |
| 304 | + fn visit_function(&mut self, _func: &Function<'a>, _flags: ScopeFlags) {} |
| 305 | + |
| 306 | + fn visit_formal_parameter(&mut self, _param: &FormalParameter<'a>) {} |
300 | 307 | } |
301 | 308 |
|
302 | 309 | /// Checks if node names returned by getNodeName matches any of the given star patterns |
|
0 commit comments