Skip to content

Commit

Permalink
reimplement 91X in libcst
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Feb 27, 2023
1 parent d8325ad commit df51441
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 232 deletions.
33 changes: 15 additions & 18 deletions flake8_trio/visitors/flake8triovisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any, Union

import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider

from ..base import Error, Statement
Expand Down Expand Up @@ -47,8 +46,6 @@ def __init__(self, shared_state: SharedState):
"typed_calls",
}

self.suppress_errors = False

# `variables` can be saved/loaded, but need a setter to not clear the reference
@property
def variables(self) -> dict[str, str]:
Expand Down Expand Up @@ -105,17 +102,16 @@ def error(
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
return

if not self.suppress_errors:
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
self.__state.problems.append(
Error(
# 7 == len('TRIO...'), so alt messages raise the original code
error_code[:7],
node.lineno,
node.col_offset,
self.error_codes[error_code],
*args,
)
)

def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
if not attrs:
Expand All @@ -131,8 +127,6 @@ def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:

def set_state(self, attrs: dict[str, Any], copy: bool = False):
for attr, value in attrs.items():
if copy and hasattr(value, "copy"):
value = value.copy()
setattr(self, attr, value)

def save_state(self, node: ast.AST, *attrs: str, copy: bool = False):
Expand Down Expand Up @@ -163,14 +157,14 @@ def add_library(self, name: str) -> None:
self.__state.library = self.__state.library + (name,)


class Flake8TrioVisitor_cst(m.MatcherDecoratableTransformer, ABC):
class Flake8TrioVisitor_cst(cst.CSTTransformer, ABC):
# abstract attribute by not providing a value
error_codes: dict[str, str] # pyright: reportUninitializedInstanceVariable=false
METADATA_DEPENDENCIES = (PositionProvider,)

def __init__(self, shared_state: SharedState):
super().__init__()
self.outer: dict[cst.BaseStatement, dict[str, Any]] = {}
self.outer: dict[cst.CSTNode, dict[str, Any]] = {}
self.__state = shared_state

self.options = self.__state.options
Expand All @@ -193,7 +187,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
value = value.copy()
setattr(self, attr, value)

def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
state = self.get_state(*attrs, copy=copy)
if node in self.outer:
# not currently used, and not gonna bother adding dedicated test
Expand All @@ -202,6 +196,9 @@ def save_state(self, node: cst.BaseStatement, *attrs: str, copy: bool = False):
else:
self.outer[node] = state

def restore_state(self, node: cst.CSTNode):
self.set_state(self.outer.pop(node, {}))

def error(
self,
node: cst.CSTNode,
Expand Down
77 changes: 75 additions & 2 deletions flake8_trio/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import ast
from fnmatch import fnmatch
from typing import TYPE_CHECKING, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar

import libcst as cst
import libcst.matchers as m
from libcst.helpers import ensure_type, get_full_name_for_node_or_raise

from ..base import Statement
from . import (
Expand All @@ -27,6 +28,7 @@

T = TypeVar("T", bound=Flake8TrioVisitor)
T_CST = TypeVar("T_CST", bound=Flake8TrioVisitor_cst)
T_EITHER = TypeVar("T_EITHER", Flake8TrioVisitor, Flake8TrioVisitor_cst)


def error_class(error_class: type[T]) -> type[T]:
Expand All @@ -41,7 +43,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]:
return error_class


def disabled_by_default(error_class: type[T]) -> type[T]:
def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]:
assert error_class.error_codes
default_disabled_error_codes.extend(error_class.error_codes)
return error_class
Expand Down Expand Up @@ -86,6 +88,19 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N
return None


def fnmatch_qualified_name_cst(
name_list: Iterable[cst.Decorator], *patterns: str
) -> str | None:
for name in name_list:
qualified_name = get_full_name_for_node_or_raise(name)

for pattern in patterns:
# strip leading "@"s for when we're working with decorators
if fnmatch(qualified_name, pattern.lstrip("@")):
return pattern
return None


# used in 103/104 and 910/911
def iter_guaranteed_once(iterable: ast.expr) -> bool:
# static container with an "elts" attribute
Expand All @@ -112,6 +127,8 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
return True
else:
return True
return False

# check for range() with literal parameters
if (
isinstance(iterable, ast.Call)
Expand All @@ -125,6 +142,62 @@ def iter_guaranteed_once(iterable: ast.expr) -> bool:
return False


def cst_literal_eval(node: cst.BaseExpression) -> Any:
ast_node = cst.Module([cst.SimpleStatementLine([cst.Expr(node)])]).code
try:
return ast.literal_eval(ast_node)
except Exception: # noqa: PIE786
return None


def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool:
# static container with an "elts" attribute
if elts := getattr(iterable, "elements", []):
for elt in elts:
assert isinstance(
elt,
(
cst.Element,
cst.DictElement,
cst.StarredElement,
cst.StarredDictElement,
),
)
# recurse starred expression
if isinstance(elt, (cst.StarredElement, cst.StarredDictElement)):
if iter_guaranteed_once_cst(elt.value):
return True
else:
return True
return False

if isinstance(iterable, cst.SimpleString):
return len(ast.literal_eval(iterable.value)) > 0

# check for range() with literal parameters
if m.matches(
iterable,
m.Call(
func=m.Name("range"),
),
):
try:
return (
len(
range(
*[
cst_literal_eval(a.value)
for a in ensure_type(iterable, cst.Call).args
]
)
)
> 0
)
except Exception: # noqa: PIE786
return False
return False


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Statement | None:
def has_exception(node: ast.expr) -> str | None:
Expand Down
25 changes: 15 additions & 10 deletions flake8_trio/visitors/visitor100.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
the timeout can only be triggered by a checkpoint.
Checkpoints on Await, Async For and Async With
"""
# if future annotations are imported then shed will reformat away the Union use
from typing import Any, Union
from __future__ import annotations

from typing import Any

import libcst as cst
import libcst.matchers as m
Expand All @@ -29,9 +30,13 @@ def __init__(self, *args: Any, **kwargs: Any):
self.has_checkpoint_stack: list[bool] = []
self.node_dict: dict[cst.With, list[AttributeCall]] = {}

def checkpoint(self) -> None:
if self.has_checkpoint_stack:
self.has_checkpoint_stack[-1] = True

def visit_With(self, node: cst.With) -> None:
if m.matches(node, m.With(asynchronous=m.Asynchronous())):
self.checkpoint_node(node)
self.checkpoint()
if res := with_has_call(
node, "fail_after", "fail_at", "move_on_after", "move_on_at", "CancelScope"
):
Expand All @@ -49,12 +54,12 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.Wit
# then: remove the with and pop out it's body
return updated_node

@m.visit(m.Await() | m.For(asynchronous=m.Asynchronous()))
# can't use m.call_if_inside(m.With), since it matches parents *or* the node itself
# need to use Union due to https://github.com/Instagram/LibCST/issues/870
def checkpoint_node(self, node: Union[cst.Await, cst.For, cst.With]):
if self.has_checkpoint_stack:
self.has_checkpoint_stack[-1] = True
def visit_For(self, node: cst.For):
if node.asynchronous is not None:
self.checkpoint()

def visit_Await(self, node: cst.Await | cst.For | cst.With):
self.checkpoint()

def visit_FunctionDef(self, node: cst.FunctionDef):
self.save_state(node, "has_checkpoint_stack", copy=True)
Expand All @@ -63,5 +68,5 @@ def visit_FunctionDef(self, node: cst.FunctionDef):
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node
15 changes: 8 additions & 7 deletions flake8_trio/visitors/visitor101.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
"""
from __future__ import annotations

from typing import Any

import libcst as cst
import libcst.matchers as m
from typing import TYPE_CHECKING, Any

from .flake8triovisitor import Flake8TrioVisitor_cst
from .helpers import (
Expand All @@ -18,6 +15,9 @@
with_has_call,
)

if TYPE_CHECKING:
import libcst as cst


@error_class_cst
class Visitor101(Flake8TrioVisitor_cst):
Expand Down Expand Up @@ -45,13 +45,14 @@ def visit_With(self, node: cst.With):
and bool(with_has_call(node, "open_nursery", *cancel_scope_names))
)

@m.leave(m.OneOf(m.With(), m.FunctionDef()))
def restore_state(
def leave_With(
self, original_node: cst.BaseStatement, updated_node: cst.BaseStatement
) -> cst.BaseStatement:
self.set_state(self.outer.pop(original_node, {}))
self.restore_state(original_node)
return updated_node

leave_FunctionDef = leave_With

def visit_FunctionDef(self, node: cst.FunctionDef):
self.save_state(node, "_yield_is_error", "_safe_decorator")
self._yield_is_error = False
Expand Down
Loading

0 comments on commit df51441

Please sign in to comment.