Skip to content

Commit d781c27

Browse files
committed
reimplement 91X in libcst
1 parent d8325ad commit d781c27

File tree

8 files changed

+471
-214
lines changed

8 files changed

+471
-214
lines changed

flake8_trio/visitors/flake8triovisitor.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ def __init__(self, shared_state: SharedState):
4747
"typed_calls",
4848
}
4949

50-
self.suppress_errors = False
51-
5250
# `variables` can be saved/loaded, but need a setter to not clear the reference
5351
@property
5452
def variables(self) -> dict[str, str]:
@@ -105,17 +103,16 @@ def error(
105103
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
106104
return
107105

108-
if not self.suppress_errors:
109-
self.__state.problems.append(
110-
Error(
111-
# 7 == len('TRIO...'), so alt messages raise the original code
112-
error_code[:7],
113-
node.lineno,
114-
node.col_offset,
115-
self.error_codes[error_code],
116-
*args,
117-
)
106+
self.__state.problems.append(
107+
Error(
108+
# 7 == len('TRIO...'), so alt messages raise the original code
109+
error_code[:7],
110+
node.lineno,
111+
node.col_offset,
112+
self.error_codes[error_code],
113+
*args,
118114
)
115+
)
119116

120117
def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
121118
if not attrs:
@@ -170,7 +167,7 @@ class Flake8TrioVisitor_cst(m.MatcherDecoratableTransformer, ABC):
170167

171168
def __init__(self, shared_state: SharedState):
172169
super().__init__()
173-
self.outer: dict[cst.BaseStatement, dict[str, Any]] = {}
170+
self.outer: dict[cst.CSTNode, dict[str, Any]] = {}
174171
self.__state = shared_state
175172

176173
self.options = self.__state.options
@@ -193,7 +190,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
193190
value = value.copy()
194191
setattr(self, attr, value)
195192

196-
def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
193+
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
197194
state = self.get_state(*attrs, copy=copy)
198195
if node in self.outer:
199196
# not currently used, and not gonna bother adding dedicated test
@@ -202,6 +199,9 @@ def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
202199
else:
203200
self.outer[node] = state
204201

202+
def restore_state(self, node: cst.CSTNode):
203+
self.set_state(self.outer.pop(node, {}))
204+
205205
def error(
206206
self,
207207
node: cst.CSTNode,

flake8_trio/visitors/helpers.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
import ast
99
from fnmatch import fnmatch
10-
from typing import TYPE_CHECKING, NamedTuple, TypeVar
10+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
1111

1212
import libcst as cst
1313
import libcst.matchers as m
14+
from libcst.helpers import ensure_type, get_full_name_for_node_or_raise
1415

1516
from ..base import Statement
1617
from . import (
@@ -27,6 +28,7 @@
2728

2829
T = TypeVar("T", bound=Flake8TrioVisitor)
2930
T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst)
31+
T_EITHER = TypeVar("T_EITHER", Flake8TrioVisitor, Flake8TrioVisitor_cst)
3032

3133

3234
def error_class(error_class: type[T]) -> type[T]:
@@ -41,7 +43,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]:
4143
return error_class
4244

4345

44-
def disabled_by_default(error_class: type[T]) -> type[T]:
46+
def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]:
4547
assert error_class.error_codes
4648
default_disabled_error_codes.extend(error_class.error_codes)
4749
return error_class
@@ -86,6 +88,21 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N
8688
return None
8789

8890

91+
def fnmatch_qualified_name_cst(
92+
name_list: Iterable[cst.Decorator], *patterns: str
93+
) -> str | None:
94+
for name in name_list:
95+
if isinstance(name, cst.Call):
96+
name = name.func
97+
qualified_name = get_full_name_for_node_or_raise(name)
98+
99+
for pattern in patterns:
100+
# strip leading "@"s for when we're working with decorators
101+
if fnmatch(qualified_name, pattern.lstrip("@")):
102+
return pattern
103+
return None
104+
105+
89106
# used in 103/104 and 910/911
90107
def iter_guaranteed_once(iterable: ast.expr) -> bool:
91108
# static container with an "elts" attribute
@@ -125,6 +142,62 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
125142
return False
126143

127144

145+
def cst_literal_eval(node: cst.BaseExpression) -> Any:
146+
ast_node = cst.Module([cst.SimpleStatementLine([cst.Expr(node)])]).code
147+
try:
148+
return ast.literal_eval(ast_node)
149+
except Exception: # noqa: PIE786
150+
return None
151+
152+
153+
def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool:
154+
# static container with an "elts" attribute
155+
if elts := getattr(iterable, "elements", []):
156+
for elt in elts:
157+
assert isinstance(
158+
elt,
159+
(
160+
cst.Element,
161+
cst.DictElement,
162+
cst.StarredElement,
163+
cst.StarredDictElement,
164+
),
165+
)
166+
# recurse starred expression
167+
if isinstance(elt, (cst.StarredElement, cst.StarredDictElement)):
168+
if iter_guaranteed_once_cst(elt.value):
169+
return True
170+
else:
171+
return True
172+
return False
173+
174+
if isinstance(iterable, cst.SimpleString):
175+
return len(ast.literal_eval(iterable.value)) > 0
176+
177+
# check for range() with literal parameters
178+
if m.matches(
179+
iterable,
180+
m.Call(
181+
func=m.Name("range"),
182+
),
183+
):
184+
try:
185+
return (
186+
len(
187+
range(
188+
*[
189+
cst_literal_eval(a.value)
190+
for a in ensure_type(iterable, cst.Call).args
191+
]
192+
)
193+
)
194+
> 0
195+
)
196+
except Exception: # noqa: PIE786
197+
return False
198+
return False
199+
200+
128201
# used in 102, 103 and 104
129202
def critical_except(node: ast.ExceptHandler) -> Statement | None:
130203
def has_exception(node: ast.expr) -> str | None:

flake8_trio/visitors/visitor100.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,5 @@ def visit_FunctionDef(self, node: cst.FunctionDef):
6363
def leave_FunctionDef(
6464
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
6565
) -> cst.FunctionDef:
66-
self.set_state(self.outer.pop(original_node, {}))
66+
self.restore_state(original_node)
6767
return updated_node

flake8_trio/visitors/visitor101.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def visit_With(self, node: cst.With):
4646
)
4747

4848
@m.leave(m.OneOf(m.With(), m.FunctionDef()))
49-
def restore_state(
49+
def restore_state_(
5050
self, original_node: cst.BaseStatement, updated_node: cst.BaseStatement
5151
) -> cst.BaseStatement:
52-
self.set_state(self.outer.pop(original_node, {}))
52+
self.restore_state(original_node)
5353
return updated_node
5454

5555
def visit_FunctionDef(self, node: cst.FunctionDef):

0 commit comments

Comments
 (0)