Skip to content

Commit

Permalink
RFC: Add Class Decorator/Metaclass/Base Class plugin
Browse files Browse the repository at this point in the history
Also use the plugin to demo adding __init__ to attr.s

Helps python#2088
  • Loading branch information
euresti committed Dec 6, 2017
1 parent 63c656d commit b84d891
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 6 deletions.
98 changes: 94 additions & 4 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Plugin system for extending mypy."""

from collections import OrderedDict
from abc import abstractmethod
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar

from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, \
DictExpr, TypeInfo, ClassDef, ARG_POS, ARG_OPT, Var, Argument, FuncDef, \
Block, SymbolTableNode, MDEF
from mypy.types import (
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType,
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, TypeVarType,
AnyType, TypeList, UnboundType, TypeOfAny
)
from mypy.messages import MessageBuilder
Expand Down Expand Up @@ -53,6 +54,14 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance:
raise NotImplementedError


class SemanticAnalyzerPluginInterface:
"""Interface for accessing semantic analyzer functionality in plugins."""

@abstractmethod
def named_type(self, qualified_name: str, args: Optional[List[Type]] = None) -> Instance:
raise NotImplementedError


# A context for a function hook that infers the return type of a function with
# a special signature.
#
Expand Down Expand Up @@ -98,6 +107,11 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance:
('context', Context),
('api', CheckerPluginInterface)])

ClassDefContext = NamedTuple(
'ClassDecoratorContext', [
('cls', ClassDef),
('api', SemanticAnalyzerPluginInterface)
])

class Plugin:
"""Base class of all type checker plugins.
Expand Down Expand Up @@ -136,7 +150,17 @@ def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
return None

# TODO: metaclass / class decorator hook
def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return None

def get_class_metaclass_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return None

def get_class_base_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return None


T = TypeVar('T')
Expand Down Expand Up @@ -182,6 +206,18 @@ def get_attribute_hook(self, fullname: str
) -> Optional[Callable[[AttributeContext], Type]]:
return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname))

def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname))

def get_class_metaclass_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_class_metaclass_hook(fullname))

def get_class_base_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_class_base_hook(fullname))

def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
for plugin in self._plugins:
hook = lookup(plugin)
Expand Down Expand Up @@ -215,6 +251,11 @@ def get_method_hook(self, fullname: str
return int_pow_callback
return None

def get_class_decorator_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], Type]]:
if fullname == 'attr.s':
return attr_s_callback


def open_callback(ctx: FunctionContext) -> Type:
"""Infer a better return type for 'open'.
Expand Down Expand Up @@ -332,3 +373,52 @@ def int_pow_callback(ctx: MethodContext) -> Type:
else:
return ctx.api.named_generic_type('builtins.float', [])
return ctx.default_return_type


def add_method(
info: TypeInfo,
method_name: str,
args: List[Argument],
ret_type: Type,
self_type: Type,
function_type: Instance) -> None:
from mypy.semanal import set_callable_name

first = [Argument(Var('self'), self_type, None, ARG_POS)]
args = first + args

arg_types = [arg.type_annotation for arg in args]
arg_names = [arg.variable.name() for arg in args]
arg_kinds = [arg.kind for arg in args]
assert None not in arg_types
signature = CallableType(arg_types, arg_kinds, arg_names,
ret_type, function_type)
func = FuncDef(method_name, args, Block([]))
func.info = info
func.is_class = False
func.type = set_callable_name(signature, func)
func._fullname = info.fullname() + '.' + method_name
info.names[method_name] = SymbolTableNode(MDEF, func)


def attr_s_callback(ctx: ClassDefContext) -> None:
"""Add an __init__ method to classes decorated with attr.s."""
info = ctx.cls.info
has_default = {} # TODO: Handle these.
args = []

for name, table in info.names.items():
if table.type:
var = Var(name.lstrip("_"), table.type)
default = has_default.get(var.name(), None)
kind = ARG_POS if default is None else ARG_OPT
args.append(Argument(var, var.type, default, kind))

add_method(
info=info,
method_name='__init__',
args=args,
ret_type=NoneTyp(),
self_type=ctx.api.named_type(info.name()),
function_type=ctx.api.named_type('__builtins__.function'),
)
29 changes: 27 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from mypy.sametypes import is_same_type
from mypy.options import Options
from mypy import experiments
from mypy.plugin import Plugin
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy import join
from mypy.util import get_prefix

Expand Down Expand Up @@ -172,7 +172,7 @@
}


class SemanticAnalyzerPass2(NodeVisitor[None]):
class SemanticAnalyzerPass2(NodeVisitor[None], SemanticAnalyzerPluginInterface):
"""Semantically analyze parsed mypy files.
The analyzer binds names and does various consistency checks for a
Expand Down Expand Up @@ -720,6 +720,31 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]:
self.calculate_abstract_status(defn.info)
self.setup_type_promotion(defn)

for decorator in defn.decorators:
if isinstance(decorator, CallExpr):
fullname = decorator.callee.fullname
else:
fullname = decorator.fullname
hook = self.plugin.get_class_decorator_hook(fullname)
if hook:
hook(ClassDefContext(defn, self))

if defn.metaclass:
metaclass_name = None
if isinstance(defn.metaclass, NameExpr):
metaclass_name = defn.metaclass.name
elif isinstance(defn.metaclass, MemberExpr):
metaclass_name = get_member_expr_fullname(
defn.metaclass)
hook = self.plugin.get_class_metaclass_hook(metaclass_name)
if hook:
hook(ClassDefContext(defn, self))

for type_info in defn.info.bases:
hook = self.plugin.get_class_base_hook(type_info.type.fullname())
if hook:
hook(ClassDefContext(defn, self))

self.leave_class()

def analyze_class_keywords(self, defn: ClassDef) -> None:
Expand Down

0 comments on commit b84d891

Please sign in to comment.