Skip to content

Commit

Permalink
Ignore NPY201 inside except blocks for compatibility with older n…
Browse files Browse the repository at this point in the history
…umpy versions (#12490)
  • Loading branch information
AlexWaygood authored Jul 24, 2024
1 parent e52be09 commit 928ffd6
Show file tree
Hide file tree
Showing 8 changed files with 413 additions and 30 deletions.
17 changes: 17 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/numpy/NPY201.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,20 @@ def func():
np.lookfor

np.NAN

try:
from numpy.lib.npyio import DataSource
except ImportError:
from numpy import DataSource

DataSource("foo").abspath() # fine (`except ImportError` branch)

try:
from numpy.rec import format_parser
from numpy import clongdouble
except ModuleNotFoundError:
from numpy import format_parser
from numpy import longcomplex as clongdouble

format_parser("foo") # fine (`except ModuleNotFoundError` branch)
clongdouble(42) # fine (`except ModuleNotFoundError` branch)
18 changes: 18 additions & 0 deletions crates/ruff_linter/resources/test/fixtures/numpy/NPY201_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,21 @@ def func():
np.ComplexWarning

np.compare_chararrays

try:
np.all([True, True])
except TypeError:
np.alltrue([True, True]) # Should emit a warning here (`except TypeError`, not `except AttributeError`)

try:
np.anyyyy([True, True])
except AttributeError:
np.sometrue([True, True]) # Should emit a warning here
# (must have an attribute access of the undeprecated name in the `try` body for it to be ignored)

try:
exc = np.exceptions.ComplexWarning
except AttributeError:
exc = np.ComplexWarning # `except AttributeError` means that this is okay

raise exc
44 changes: 18 additions & 26 deletions crates/ruff_linter/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ use log::debug;

use ruff_diagnostics::{Diagnostic, IsolationLevel};
use ruff_notebook::{CellOffsets, NotebookIndex};
use ruff_python_ast::helpers::{
collect_import_from_member, extract_handled_exceptions, is_docstring_stmt, to_module_path,
};
use ruff_python_ast::helpers::{collect_import_from_member, is_docstring_stmt, to_module_path};
use ruff_python_ast::identifier::Identifier;
use ruff_python_ast::name::QualifiedName;
use ruff_python_ast::str::Quote;
Expand Down Expand Up @@ -834,32 +832,22 @@ impl<'a> Visitor<'a> for Checker<'a> {
self.semantic.pop_scope();
self.visit_expr(name);
}
Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
}) => {
let mut handled_exceptions = Exceptions::empty();
for type_ in extract_handled_exceptions(handlers) {
if let Some(builtins_name) = self.semantic.resolve_builtin_symbol(type_) {
match builtins_name {
"NameError" => handled_exceptions |= Exceptions::NAME_ERROR,
"ModuleNotFoundError" => {
handled_exceptions |= Exceptions::MODULE_NOT_FOUND_ERROR;
}
"ImportError" => handled_exceptions |= Exceptions::IMPORT_ERROR,
_ => {}
}
}
}

Stmt::Try(
try_node @ ast::StmtTry {
body,
handlers,
orelse,
finalbody,
..
},
) => {
// Iterate over the `body`, then the `handlers`, then the `orelse`, then the
// `finalbody`, but treat the body and the `orelse` as a single branch for
// flow analysis purposes.
let branch = self.semantic.push_branch();
self.semantic.handled_exceptions.push(handled_exceptions);
self.semantic
.handled_exceptions
.push(Exceptions::from_try_stmt(try_node, &self.semantic));
self.visit_body(body);
self.semantic.handled_exceptions.pop();
self.semantic.pop_branch();
Expand Down Expand Up @@ -1837,7 +1825,7 @@ impl<'a> Checker<'a> {
name: &'a str,
range: TextRange,
kind: BindingKind<'a>,
flags: BindingFlags,
mut flags: BindingFlags,
) -> BindingId {
// Determine the scope to which the binding belongs.
// Per [PEP 572](https://peps.python.org/pep-0572/#scope-of-the-target), named
Expand All @@ -1853,6 +1841,10 @@ impl<'a> Checker<'a> {
self.semantic.scope_id
};

if self.semantic.in_exception_handler() {
flags |= BindingFlags::IN_EXCEPT_HANDLER;
}

// Create the `Binding`.
let binding_id = self.semantic.push_binding(range, kind, flags);

Expand Down
241 changes: 239 additions & 2 deletions crates/ruff_linter/src/rules/numpy/rules/numpy_2_0_deprecation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::Expr;
use ruff_python_semantic::Modules;
use ruff_python_ast::name::{QualifiedName, QualifiedNameBuilder};
use ruff_python_ast::statement_visitor::StatementVisitor;
use ruff_python_ast::visitor::Visitor;
use ruff_python_ast::{self as ast, Expr};
use ruff_python_semantic::{Exceptions, Modules, SemanticModel};
use ruff_text_size::Ranged;

use crate::checkers::ast::Checker;
Expand Down Expand Up @@ -665,6 +668,10 @@ pub(crate) fn numpy_2_0_deprecation(checker: &mut Checker, expr: &Expr) {
_ => return,
};

if is_guarded_by_try_except(expr, &replacement, semantic) {
return;
}

let mut diagnostic = Diagnostic::new(
Numpy2Deprecation {
existing: replacement.existing.to_string(),
Expand Down Expand Up @@ -701,3 +708,233 @@ pub(crate) fn numpy_2_0_deprecation(checker: &mut Checker, expr: &Expr) {
};
checker.diagnostics.push(diagnostic);
}

/// Ignore attempts to access a `numpy` member via its deprecated name
/// if the access takes place in an `except` block that provides compatibility
/// with older numpy versions.
///
/// For attribute accesses (e.g. `np.ComplexWarning`), we only ignore the violation
/// if it's inside an `except AttributeError` block, and the member is accessed
/// through its non-deprecated name in the associated `try` block.
///
/// For uses of the `numpy` member where it's simply an `ExprName` node,
/// we check to see how the `numpy` member was bound. If it was bound via a
/// `from numpy import foo` statement, we check to see if that import statement
/// took place inside an `except ImportError` or `except ModuleNotFoundError` block.
/// If so, and if the `numpy` member was imported through its non-deprecated name
/// in the associated try block, we ignore the violation in the same way.
///
/// Examples:
///
/// ```py
/// import numpy as np
///
/// try:
/// np.all([True, True])
/// except AttributeError:
/// np.alltrue([True, True]) # Okay
///
/// try:
/// from numpy.exceptions import ComplexWarning
/// except ImportError:
/// from numpy import ComplexWarning
///
/// x = ComplexWarning() # Okay
/// ```
fn is_guarded_by_try_except(
expr: &Expr,
replacement: &Replacement,
semantic: &SemanticModel,
) -> bool {
match expr {
Expr::Attribute(_) => {
if !semantic.in_exception_handler() {
return false;
}
let Some(try_node) = semantic
.current_statements()
.find_map(|stmt| stmt.as_try_stmt())
else {
return false;
};
let suspended_exceptions = Exceptions::from_try_stmt(try_node, semantic);
if !suspended_exceptions.contains(Exceptions::ATTRIBUTE_ERROR) {
return false;
}
try_block_contains_undeprecated_attribute(try_node, &replacement.details, semantic)
}
Expr::Name(ast::ExprName { id, .. }) => {
let Some(binding_id) = semantic.lookup_symbol(id.as_str()) else {
return false;
};
let binding = semantic.binding(binding_id);
if !binding.is_external() {
return false;
}
if !binding.in_exception_handler() {
return false;
}
let Some(try_node) = binding.source.and_then(|import_id| {
semantic
.statements(import_id)
.find_map(|stmt| stmt.as_try_stmt())
}) else {
return false;
};
let suspended_exceptions = Exceptions::from_try_stmt(try_node, semantic);
if !suspended_exceptions
.intersects(Exceptions::IMPORT_ERROR | Exceptions::MODULE_NOT_FOUND_ERROR)
{
return false;
}
try_block_contains_undeprecated_import(try_node, &replacement.details)
}
_ => false,
}
}

/// Given an [`ast::StmtTry`] node, does the `try` branch of that node
/// contain any [`ast::ExprAttribute`] nodes that indicate the numpy
/// member is being accessed from the non-deprecated location?
fn try_block_contains_undeprecated_attribute(
try_node: &ast::StmtTry,
replacement_details: &Details,
semantic: &SemanticModel,
) -> bool {
let Details::AutoImport {
path,
name,
compatibility: _,
} = replacement_details
else {
return false;
};
let undeprecated_qualified_name = {
let mut builder = QualifiedNameBuilder::default();
for part in path.split('.') {
builder.push(part);
}
builder.push(name);
builder.build()
};
let mut attribute_searcher = AttributeSearcher::new(undeprecated_qualified_name, semantic);
attribute_searcher.visit_body(&try_node.body);
attribute_searcher.found_attribute
}

/// AST visitor that searches an AST tree for [`ast::ExprAttribute`] nodes
/// that match a certain [`QualifiedName`].
struct AttributeSearcher<'a> {
attribute_to_find: QualifiedName<'a>,
semantic: &'a SemanticModel<'a>,
found_attribute: bool,
}

impl<'a> AttributeSearcher<'a> {
fn new(attribute_to_find: QualifiedName<'a>, semantic: &'a SemanticModel<'a>) -> Self {
Self {
attribute_to_find,
semantic,
found_attribute: false,
}
}
}

impl Visitor<'_> for AttributeSearcher<'_> {
fn visit_expr(&mut self, expr: &'_ Expr) {
if self.found_attribute {
return;
}
if expr.is_attribute_expr()
&& self
.semantic
.resolve_qualified_name(expr)
.is_some_and(|qualified_name| qualified_name == self.attribute_to_find)
{
self.found_attribute = true;
return;
}
ast::visitor::walk_expr(self, expr);
}

fn visit_stmt(&mut self, stmt: &ruff_python_ast::Stmt) {
if !self.found_attribute {
ast::visitor::walk_stmt(self, stmt);
}
}

fn visit_body(&mut self, body: &[ruff_python_ast::Stmt]) {
for stmt in body {
self.visit_stmt(stmt);
if self.found_attribute {
return;
}
}
}
}

/// Given an [`ast::StmtTry`] node, does the `try` branch of that node
/// contain any [`ast::StmtImportFrom`] nodes that indicate the numpy
/// member is being imported from the non-deprecated location?
fn try_block_contains_undeprecated_import(
try_node: &ast::StmtTry,
replacement_details: &Details,
) -> bool {
let Details::AutoImport {
path,
name,
compatibility: _,
} = replacement_details
else {
return false;
};
let mut import_searcher = ImportSearcher::new(path, name);
import_searcher.visit_body(&try_node.body);
import_searcher.found_import
}

/// AST visitor that searches an AST tree for [`ast::StmtImportFrom`] nodes
/// that match a certain [`QualifiedName`].
struct ImportSearcher<'a> {
module: &'a str,
name: &'a str,
found_import: bool,
}

impl<'a> ImportSearcher<'a> {
fn new(module: &'a str, name: &'a str) -> Self {
Self {
module,
name,
found_import: false,
}
}
}

impl StatementVisitor<'_> for ImportSearcher<'_> {
fn visit_stmt(&mut self, stmt: &ast::Stmt) {
if self.found_import {
return;
}
if let ast::Stmt::ImportFrom(ast::StmtImportFrom { module, names, .. }) = stmt {
if module.as_ref().is_some_and(|module| module == self.module)
&& names
.iter()
.any(|ast::Alias { name, .. }| name == self.name)
{
self.found_import = true;
return;
}
}
ast::statement_visitor::walk_stmt(self, stmt);
}

fn visit_body(&mut self, body: &[ruff_python_ast::Stmt]) {
for stmt in body {
self.visit_stmt(stmt);
if self.found_import {
return;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ NPY201.py:72:5: NPY201 [*] `np.NAN` will be removed in NumPy 2.0. Use `numpy.nan
71 |
72 | np.NAN
| ^^^^^^ NPY201
73 |
74 | try:
|
= help: Replace with `numpy.nan`

Expand All @@ -579,3 +581,6 @@ NPY201.py:72:5: NPY201 [*] `np.NAN` will be removed in NumPy 2.0. Use `numpy.nan
71 71 |
72 |- np.NAN
72 |+ np.nan
73 73 |
74 74 | try:
75 75 | from numpy.lib.npyio import DataSource
Loading

0 comments on commit 928ffd6

Please sign in to comment.