Running dataclass transform in a later pass to fix crashes (#12762)
The dataclass plugin could crash if it encountered a placeholder. Fix the issue by
running the plugin after the main semantic analysis pass, when all placeholders have
been resolved.

Also add a new hook called get_class_decorator_hook_2 that is used by the
dataclass plugin.

We may want to make a similar change to the attrs plugin, but let's change one thing
at a time.

Fix #12685.
JukkaL committed May 11, 2022
1 parent e1c03ab commit 03901ef
Showing 6 changed files with 217 additions and 37 deletions.
32 changes: 30 additions & 2 deletions mypy/plugin.py
@@ -692,9 +692,33 @@ def get_class_decorator_hook(self, fullname: str
The plugin can modify a TypeInfo _in place_ (for example add some generated
methods to the symbol table). This hook is called after the class body was
semantically analyzed.
semantically analyzed, but *there may still be placeholders* (typically
caused by forward references).
The hook is called with full names of all class decorators, for example
NOTE: Usually get_class_decorator_hook_2 is the better option, since it
guarantees that there are no placeholders.
The hook is called with full names of all class decorators.
The hook can be called multiple times per class, so it must be
idempotent.
"""
return None

def get_class_decorator_hook_2(self, fullname: str
) -> Optional[Callable[[ClassDefContext], bool]]:
"""Update class definition for given class decorators.
Similar to get_class_decorator_hook, but this runs in a later pass when
placeholders have been resolved.
The hook can return False if some base class hasn't been
processed yet using class hooks. It causes all class hooks
(that are run in this same pass) to be invoked another time for
the file(s) currently being processed.
The hook can be called multiple times per class, so it must be
idempotent.
"""
return None

@@ -815,6 +839,10 @@ def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname))

def get_class_decorator_hook_2(self, fullname: str
) -> Optional[Callable[[ClassDefContext], bool]]:
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook_2(fullname))

def get_metaclass_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_metaclass_hook(fullname))
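For context, a third-party plugin opts into the new pass by overriding get_class_decorator_hook_2. The sketch below is illustrative only: the frozenbox.frozen_box decorator, the callback, and the FrozenBoxPlugin class are hypothetical, but the hook signature and the meaning of the bool result follow the diff above.

from typing import Callable, Optional

from mypy.plugin import ClassDefContext, Plugin


def _frozen_box_callback(ctx: ClassDefContext) -> bool:
    # Runs after the main semantic analysis pass, so ctx.cls.info and its MRO
    # contain no placeholder nodes.  Must be idempotent: it can be invoked
    # several times for the same class.
    ctx.cls.info.metadata.setdefault('frozen_box', {})
    # Returning False instead would request another round of class hooks over
    # the file(s) currently being processed (e.g. if a base class isn't ready).
    return True


class FrozenBoxPlugin(Plugin):
    def get_class_decorator_hook_2(self, fullname: str
                                   ) -> Optional[Callable[[ClassDefContext], bool]]:
        if fullname == 'frozenbox.frozen_box':  # hypothetical decorator
            return _frozen_box_callback
        return None


def plugin(version: str) -> type:
    return FrozenBoxPlugin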
47 changes: 35 additions & 12 deletions mypy/plugins/dataclasses.py
@@ -107,10 +107,19 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:


class DataclassTransformer:
"""Implement the behavior of @dataclass.
Note that this may be executed multiple times on the same class, so
everything here must be idempotent.
This runs after the main semantic analysis pass, so you can assume that
there are no placeholders.
"""

def __init__(self, ctx: ClassDefContext) -> None:
self._ctx = ctx

def transform(self) -> None:
def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying
dataclass so as to ensure it is fully type checked according
to the rules in PEP 557.
@@ -119,12 +128,11 @@ def transform(self) -> None:
info = self._ctx.cls.info
attributes = self.collect_attributes()
if attributes is None:
# Some definitions are not ready, defer() should be already called.
return
# Some definitions are not ready. We need another pass.
return False
for attr in attributes:
if attr.type is None:
ctx.api.defer()
return
return False
decorator_arguments = {
'init': _get_decorator_bool_argument(self._ctx, 'init', True),
'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
@@ -236,6 +244,8 @@ def transform(self) -> None:
'frozen': decorator_arguments['frozen'],
}

return True

def add_slots(self,
info: TypeInfo,
attributes: List[DataclassAttribute],
@@ -294,6 +304,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
b: SomeOtherType = ...
are collected.
Return None if some dataclass base class hasn't been processed
yet and thus we'll need to ask for another pass.
"""
# First, collect attributes belonging to the current class.
ctx = self._ctx
@@ -315,14 +328,11 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:

sym = cls.info.names.get(lhs.name)
if sym is None:
# This name is likely blocked by a star import. We don't need to defer because
# defer() is already called by mark_incomplete().
# There was probably a semantic analysis error.
continue

node = sym.node
if isinstance(node, PlaceholderNode):
# This node is not ready yet.
return None
assert not isinstance(node, PlaceholderNode)
assert isinstance(node, Var)

# x: ClassVar[int] is ignored by dataclasses.
@@ -390,6 +400,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass_tag' in info.metadata and 'dataclass' not in info.metadata:
# We haven't processed the base class yet. Need another pass.
return None
if 'dataclass' not in info.metadata:
continue

@@ -517,11 +530,21 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
)


def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
def dataclass_tag_callback(ctx: ClassDefContext) -> None:
"""Record that we have a dataclass in the main semantic analysis pass.
The later pass implemented by DataclassTransformer will use this
to detect dataclasses in base classes.
"""
# The value is ignored, only the existence matters.
ctx.cls.info.metadata['dataclass_tag'] = {}


def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
"""Hooks into the class typechecking process to add support for dataclasses.
"""
transformer = DataclassTransformer(ctx)
transformer.transform()
return transformer.transform()


def _collect_field_args(expr: Expression,
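The same two-phase shape generalizes beyond dataclasses: a main-pass callback that only tags the class in TypeInfo.metadata, and a late-pass callback that does the real work and uses the tag to spot not-yet-processed base classes. A minimal sketch, assuming hypothetical metadata keys 'myplugin_tag' and 'myplugin' and hypothetical callback names:

from mypy.plugin import ClassDefContext


def tag_callback(ctx: ClassDefContext) -> None:
    # Main semantic analysis pass: only record that the decorator was seen.
    # The value is ignored; only the key's existence matters.
    ctx.cls.info.metadata['myplugin_tag'] = {}


def transform_callback(ctx: ClassDefContext) -> bool:
    # Late pass: the MRO is complete and contains no placeholders.
    info = ctx.cls.info
    for base in info.mro[1:-1]:
        if 'myplugin_tag' in base.metadata and 'myplugin' not in base.metadata:
            # A decorated base class hasn't been transformed yet; ask for
            # another round of class hooks for the current file(s).
            return False
    # ... synthesize methods/attributes here, idempotently ...
    info.metadata['myplugin'] = {}
    return True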
11 changes: 10 additions & 1 deletion mypy/plugins/default.py
@@ -117,12 +117,21 @@ def get_class_decorator_hook(self, fullname: str
auto_attribs_default=None,
)
elif fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
return dataclasses.dataclass_tag_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback

return None

def get_class_decorator_hook_2(self, fullname: str
) -> Optional[Callable[[ClassDefContext], bool]]:
from mypy.plugins import dataclasses

if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback

return None


def contextmanager_callback(ctx: FunctionContext) -> Type:
"""Infer a better return type for 'contextlib.contextmanager'."""
37 changes: 19 additions & 18 deletions mypy/semanal.py
@@ -1234,43 +1234,44 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:

def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
"""Apply a plugin hook that may infer a more precise definition for a class."""
def get_fullname(expr: Expression) -> Optional[str]:
if isinstance(expr, CallExpr):
return get_fullname(expr.callee)
elif isinstance(expr, IndexExpr):
return get_fullname(expr.base)
elif isinstance(expr, RefExpr):
if expr.fullname:
return expr.fullname
# If we don't have a fullname look it up. This happens because base classes are
# analyzed in a different manner (see exprtotype.py) and therefore those AST
# nodes will not have full names.
sym = self.lookup_type_node(expr)
if sym:
return sym.fullname
return None

for decorator in defn.decorators:
decorator_name = get_fullname(decorator)
decorator_name = self.get_fullname_for_hook(decorator)
if decorator_name:
hook = self.plugin.get_class_decorator_hook(decorator_name)
if hook:
hook(ClassDefContext(defn, decorator, self))

if defn.metaclass:
metaclass_name = get_fullname(defn.metaclass)
metaclass_name = self.get_fullname_for_hook(defn.metaclass)
if metaclass_name:
hook = self.plugin.get_metaclass_hook(metaclass_name)
if hook:
hook(ClassDefContext(defn, defn.metaclass, self))

for base_expr in defn.base_type_exprs:
base_name = get_fullname(base_expr)
base_name = self.get_fullname_for_hook(base_expr)
if base_name:
hook = self.plugin.get_base_class_hook(base_name)
if hook:
hook(ClassDefContext(defn, base_expr, self))

def get_fullname_for_hook(self, expr: Expression) -> Optional[str]:
if isinstance(expr, CallExpr):
return self.get_fullname_for_hook(expr.callee)
elif isinstance(expr, IndexExpr):
return self.get_fullname_for_hook(expr.base)
elif isinstance(expr, RefExpr):
if expr.fullname:
return expr.fullname
# If we don't have a fullname look it up. This happens because base classes are
# analyzed in a different manner (see exprtotype.py) and therefore those AST
# nodes will not have full names.
sym = self.lookup_type_node(expr)
if sym:
return sym.fullname
return None

def analyze_class_keywords(self, defn: ClassDef) -> None:
for value in defn.keywords.values():
value.accept(self)
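As a small, hedged illustration of what get_fullname_for_hook resolves: a bare decorator reference and a decorator call both map to the same full name, which is what the hook lookup keys on (an IndexExpr such as deco[int] would likewise recurse into its base).

from dataclasses import dataclass


@dataclass               # RefExpr: fullname is "dataclasses.dataclass"
class Plain:
    x: int


@dataclass(frozen=True)  # CallExpr: recurse into the callee, same fullname
class Configured:
    y: int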
59 changes: 56 additions & 3 deletions mypy/semanal_main.py
Expand Up @@ -45,6 +45,8 @@
from mypy.checker import FineGrainedDeferredNode
from mypy.server.aststrip import SavedAttributes
from mypy.util import is_typeshed_file
from mypy.options import Options
from mypy.plugin import ClassDefContext
import mypy.build

if TYPE_CHECKING:
@@ -82,6 +84,8 @@ def semantic_analysis_for_scc(graph: 'Graph', scc: List[str], errors: Errors) ->
apply_semantic_analyzer_patches(patches)
# This pass might need fallbacks calculated above.
check_type_arguments(graph, scc, errors)
# Run class decorator hooks (they require complete MROs and no placeholders).
apply_class_plugin_hooks(graph, scc, errors)
calculate_class_properties(graph, scc, errors)
check_blockers(graph, scc)
# Clean-up builtins, so that TypeVar etc. are not accessible without importing.
@@ -132,6 +136,7 @@ def semantic_analysis_for_targets(

check_type_arguments_in_targets(nodes, state, state.manager.errors)
calculate_class_properties(graph, [state.id], state.manager.errors)
apply_class_plugin_hooks(graph, [state.id], state.manager.errors)


def restore_saved_attrs(saved_attrs: SavedAttributes) -> None:
@@ -382,14 +387,62 @@ def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], stat
target.node.accept(analyzer)


def apply_class_plugin_hooks(graph: 'Graph', scc: List[str], errors: Errors) -> None:
"""Apply class plugin hooks within a SCC.
We run these after to the main semantic analysis so that the hooks
don't need to deal with incomplete definitions such as placeholder
types.
Note that some hooks incorrectly run during the main semantic
analysis pass, for historical reasons.
"""
num_passes = 0
incomplete = True
# If we encounter a base class that has not been processed, we'll run another
# pass. This should eventually reach a fixed point.
while incomplete:
assert num_passes < 10, "Internal error: too many class plugin hook passes"
num_passes += 1
incomplete = False
for module in scc:
state = graph[module]
tree = state.tree
assert tree
for _, node, _ in tree.local_definitions():
if isinstance(node.node, TypeInfo):
if not apply_hooks_to_class(state.manager.semantic_analyzer,
module, node.node, state.options, tree, errors):
incomplete = True


def apply_hooks_to_class(self: SemanticAnalyzer,
module: str,
info: TypeInfo,
options: Options,
file_node: MypyFile,
errors: Errors) -> bool:
# TODO: Move more class-related hooks here?
defn = info.defn
ok = True
for decorator in defn.decorators:
with self.file_context(file_node, options, info):
decorator_name = self.get_fullname_for_hook(decorator)
if decorator_name:
hook = self.plugin.get_class_decorator_hook_2(decorator_name)
if hook:
ok = ok and hook(ClassDefContext(defn, decorator, self))
return ok


def calculate_class_properties(graph: 'Graph', scc: List[str], errors: Errors) -> None:
for module in scc:
tree = graph[module].tree
state = graph[module]
tree = state.tree
assert tree
for _, node, _ in tree.local_definitions():
if isinstance(node.node, TypeInfo):
saved = (module, node.node, None) # module, class, function
with errors.scope.saved_scope(saved) if errors.scope else nullcontext():
with state.manager.semantic_analyzer.file_context(tree, state.options, node.node):
calculate_class_abstract_status(node.node, tree.is_stub, errors)
check_protocol_status(node.node, errors)
calculate_class_vars(node.node)
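One scenario in which the fixpoint loop genuinely needs a second pass is an import cycle (a single SCC) where a dataclass inherits from a dataclass in the other module; the module and class names below are hypothetical.

# left.py -- forms an import cycle (one SCC) with right.py
from dataclasses import dataclass

import right


@dataclass
class Left(right.Right):
    x: int


# right.py
from dataclasses import dataclass

import left  # completes the cycle; otherwise unused here


@dataclass
class Right:
    y: int

If left is processed first, the late hook for Left finds 'dataclass_tag' but no 'dataclass' metadata on Right, returns False, and the whole round of class hooks is repeated; the second pass then succeeds, and the assert above caps the loop at 10 passes.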