Skip to content

Commit

Permalink
Use referencial equality in traversal helper methods (#13895)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser authored Oct 24, 2024
1 parent de4181d commit e402e27
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub(crate) fn needless_bool(checker: &mut Checker, stmt: &Stmt) {
.semantic()
.current_statement_parent()
.and_then(|parent| traversal::suite(stmt, parent))
.and_then(|suite| traversal::next_sibling(stmt, suite))
.and_then(|suite| suite.next_sibling())
else {
return;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) {
// - `for` loop followed by `return True` or `return False`.
let Some(terminal) = match_else_return(stmt).or_else(|| {
let parent = checker.semantic().current_statement_parent()?;
let suite = traversal::suite(stmt, parent)?;
let sibling = traversal::next_sibling(stmt, suite)?;
let sibling = traversal::suite(stmt, parent)?.next_sibling()?;
match_sibling_return(stmt, sibling)
}) else {
return;
Expand Down
16 changes: 9 additions & 7 deletions crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use rustc_hash::FxHashMap;
use ast::traversal;
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::traversal::EnclosingSuite;
use ruff_python_ast::{self as ast, Expr, Stmt};
use ruff_python_codegen::Generator;
use ruff_python_semantic::analyze::typing::is_list;
Expand Down Expand Up @@ -179,32 +180,33 @@ fn match_consecutive_appends<'a>(

// In order to match consecutive statements, we need to go to the tree ancestor of the
// given statement, find its position there, and match all 'appends' from there.
let siblings: &[Stmt] = if semantic.at_top_level() {
let suite = if semantic.at_top_level() {
// If the statement is at the top level, we should go to the parent module.
// Module is available in the definitions list.
semantic.definitions.python_ast()?
EnclosingSuite::new(semantic.definitions.python_ast()?, stmt)?
} else {
// Otherwise, go to the parent, and take its body as a sequence of siblings.
semantic
.current_statement_parent()
.and_then(|parent| traversal::suite(stmt, parent))?
};

let stmt_index = siblings.iter().position(|sibling| sibling == stmt)?;

// We shouldn't repeat the same work for many 'appends' that go in a row. Let's check
// that this statement is at the beginning of such a group.
if stmt_index != 0 && match_append(semantic, &siblings[stmt_index - 1]).is_some() {
if suite
.previous_sibling()
.is_some_and(|previous_stmt| match_append(semantic, previous_stmt).is_some())
{
return None;
}

// Starting from the next statement, let's match all appends and make a vector.
Some(
std::iter::once(append)
.chain(
siblings
suite
.next_siblings()
.iter()
.skip(stmt_index + 1)
.map_while(|sibling| match_append(semantic, sibling)),
)
.collect(),
Expand Down
110 changes: 55 additions & 55 deletions crates/ruff_python_ast/src/traversal.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,81 @@
//! Utilities for manually traversing a Python AST.
use crate::{self as ast, ExceptHandler, Stmt, Suite};
use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt};

/// Given a [`Stmt`] and its parent, return the [`Suite`] that contains the [`Stmt`].
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> {
/// Given a [`Stmt`] and its parent, return the [`ast::Suite`] that contains the [`Stmt`].
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<EnclosingSuite<'a>> {
// TODO: refactor this to work without a parent, ie when `stmt` is at the top level
match parent {
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body),
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body),
Stmt::For(ast::StmtFor { body, orelse, .. }) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else {
None
}
}
Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else {
None
}
}
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::For(ast::StmtFor { body, orelse, .. }) => [body, orelse]
.iter()
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::While(ast::StmtWhile { body, orelse, .. }) => [body, orelse]
.iter()
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::If(ast::StmtIf {
body,
elif_else_clauses,
..
}) => {
if body.contains(stmt) {
Some(body)
} else {
elif_else_clauses
.iter()
.map(|elif_else_clause| &elif_else_clause.body)
.find(|body| body.contains(stmt))
}
}
Stmt::With(ast::StmtWith { body, .. }) => Some(body),
}) => [body]
.into_iter()
.chain(elif_else_clauses.iter().map(|clause| &clause.body))
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
Stmt::With(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt),
Stmt::Match(ast::StmtMatch { cases, .. }) => cases
.iter()
.map(|case| &case.body)
.find(|body| body.contains(stmt)),
.find_map(|body| EnclosingSuite::new(body, stmt)),
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
if body.contains(stmt) {
Some(body)
} else if orelse.contains(stmt) {
Some(orelse)
} else if finalbody.contains(stmt) {
Some(finalbody)
} else {
}) => [body, orelse, finalbody]
.into_iter()
.chain(
handlers
.iter()
.filter_map(ExceptHandler::as_except_handler)
.map(|handler| &handler.body)
.find(|body| body.contains(stmt))
}
}
.map(|handler| &handler.body),
)
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
_ => None,
}
}

/// Given a [`Stmt`] and its containing [`Suite`], return the next [`Stmt`] in the [`Suite`].
pub fn next_sibling<'a>(stmt: &'a Stmt, suite: &'a Suite) -> Option<&'a Stmt> {
let mut iter = suite.iter();
while let Some(sibling) = iter.next() {
if sibling == stmt {
return iter.next();
}
pub struct EnclosingSuite<'a> {
suite: &'a [Stmt],
position: usize,
}

impl<'a> EnclosingSuite<'a> {
pub fn new(suite: &'a [Stmt], stmt: &'a Stmt) -> Option<Self> {
let position = suite
.iter()
.position(|sibling| AnyNodeRef::ptr_eq(sibling.into(), stmt.into()))?;

Some(EnclosingSuite { suite, position })
}

pub fn next_sibling(&self) -> Option<&'a Stmt> {
self.suite.get(self.position + 1)
}

pub fn next_siblings(&self) -> &'a [Stmt] {
self.suite.get(self.position + 1..).unwrap_or_default()
}

pub fn previous_sibling(&self) -> Option<&'a Stmt> {
self.suite.get(self.position.checked_sub(1)?)
}
}

impl std::ops::Deref for EnclosingSuite<'_> {
type Target = [Stmt];

fn deref(&self) -> &Self::Target {
self.suite
}
None
}

0 comments on commit e402e27

Please sign in to comment.