Skip to content

Commit d405634

Browse files
committed
Add plugin hook for extracting types from docstrings
1 parent bc864c5 commit d405634

File tree

10 files changed

+211
-18
lines changed

10 files changed

+211
-18
lines changed

mypy/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def parse_file(self, id: str, path: str, source: str, ignore_errors: bool) -> My
532532
Raise CompileError if there is a parse error.
533533
"""
534534
num_errs = self.errors.num_messages()
535-
tree = parse(source, path, self.errors, options=self.options)
535+
tree = parse(source, path, self.errors, options=self.options, plugin=self.plugin)
536536
tree._fullname = id
537537

538538
if self.errors.num_messages() != num_errs:

mypy/fastparse.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from mypy import messages
3232
from mypy.errors import Errors
3333
from mypy.options import Options
34+
from mypy.plugin import Plugin, DocstringParserHook
3435

3536
try:
3637
from typed_ast import ast3
@@ -58,10 +59,12 @@
5859

5960
TYPE_COMMENT_SYNTAX_ERROR = 'syntax error in type comment'
6061
TYPE_COMMENT_AST_ERROR = 'invalid type comment or annotation'
62+
TYPE_COMMENT_DOCSTRING_ERROR = ('Arguments parsed from docstring are not '
63+
'present in function signature: {} not in {}')
6164

6265

6366
def parse(source: Union[str, bytes], fnam: str = None, errors: Errors = None,
64-
options: Options = Options()) -> MypyFile:
67+
options: Options = Options(), plugin: Plugin = Plugin()) -> MypyFile:
6568

6669
"""Parse a source file, without doing any semantic analysis.
6770
@@ -85,6 +88,7 @@ def parse(source: Union[str, bytes], fnam: str = None, errors: Errors = None,
8588
tree = ASTConverter(options=options,
8689
is_stub=is_stub_file,
8790
errors=errors,
91+
plugin=plugin,
8892
).visit(ast)
8993
tree.path = fnam
9094
tree.is_stub = is_stub_file
@@ -112,6 +116,31 @@ def parse_type_comment(type_comment: str, line: int, errors: Optional[Errors]) -
112116
return TypeConverter(errors, line=line).visit(typ.body)
113117

114118

119+
def parse_docstring(hook: DocstringParserHook, docstring: str, arg_names: List[str],
120+
line: int, errors: Errors) -> Optional[Tuple[List[Type], Type]]:
121+
"""Parse a docstring and return type representations.
122+
123+
Returns a 2-tuple: (list of arguments Types, and return Type).
124+
"""
125+
126+
def pop_and_convert(name: str) -> Optional[Type]:
127+
t = type_map.pop(name, None)
128+
if t is None:
129+
return AnyType()
130+
else:
131+
return parse_type_comment(t[0], t[1], errors)
132+
133+
type_map = hook(docstring, line, errors)
134+
if type_map:
135+
arg_types = [pop_and_convert(name) for name in arg_names]
136+
return_type = pop_and_convert('return')
137+
if type_map:
138+
errors.report(line, 0,
139+
TYPE_COMMENT_DOCSTRING_ERROR.format(list(type_map), arg_names))
140+
return arg_types, return_type
141+
return None
142+
143+
115144
def with_line(f: Callable[['ASTConverter', T], U]) -> Callable[['ASTConverter', T], U]:
116145
@wraps(f)
117146
def wrapper(self: 'ASTConverter', ast: T) -> U:
@@ -141,13 +170,15 @@ class ASTConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931
141170
def __init__(self,
142171
options: Options,
143172
is_stub: bool,
144-
errors: Errors) -> None:
173+
errors: Errors,
174+
plugin: Plugin) -> None:
145175
self.class_nesting = 0
146176
self.imports = [] # type: List[ImportBase]
147177

148178
self.options = options
149179
self.is_stub = is_stub
150180
self.errors = errors
181+
self.plugin = plugin
151182

152183
def fail(self, msg: str, line: int, column: int) -> None:
153184
self.errors.report(line, column, msg)
@@ -302,8 +333,9 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
302333
args = self.transform_args(n.args, n.lineno, no_type_check=no_type_check)
303334

304335
arg_kinds = [arg.kind for arg in args]
305-
arg_names = [arg.variable.name() for arg in args] # type: List[Optional[str]]
306-
arg_names = [None if argument_elide_name(name) else name for name in arg_names]
336+
real_names = [arg.variable.name() for arg in args] # type: List[str]
337+
arg_names = [None if argument_elide_name(name) else name
338+
for name in real_names] # type: List[Optional[str]]
307339
if special_function_elide_names(n.name):
308340
arg_names = [None] * len(arg_names)
309341
arg_types = None # type: List[Type]
@@ -345,6 +377,14 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
345377
return_type = TypeConverter(self.errors, line=n.returns.lineno
346378
if n.returns else n.lineno).visit(n.returns)
347379

380+
docstring_hook = self.plugin.get_docstring_parser_hook()
381+
if docstring_hook is not None and not any(arg_types) and return_type is None:
382+
doc = ast3.get_docstring(n, clean=False)
383+
if doc:
384+
types = parse_docstring(docstring_hook, doc, real_names, n.lineno, self.errors)
385+
if types is not None:
386+
arg_types, return_type = types
387+
348388
for arg, arg_type in zip(args, arg_types):
349389
self.set_type_optional(arg_type, arg.initializer)
350390

mypy/fastparse2.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
from mypy import experiments
4242
from mypy import messages
4343
from mypy.errors import Errors
44-
from mypy.fastparse import TypeConverter, parse_type_comment
44+
from mypy.fastparse import (TypeConverter, parse_type_comment,
45+
parse_docstring)
4546
from mypy.options import Options
47+
from mypy.plugin import Plugin
4648

4749
try:
4850
from typed_ast import ast27
@@ -74,7 +76,7 @@
7476

7577

7678
def parse(source: Union[str, bytes], fnam: str = None, errors: Errors = None,
77-
options: Options = Options()) -> MypyFile:
79+
options: Options = Options(), plugin: Plugin = Plugin()) -> MypyFile:
7880
"""Parse a source file, without doing any semantic analysis.
7981
8082
Return the parse tree. If errors is not provided, raise ParseError
@@ -92,6 +94,7 @@ def parse(source: Union[str, bytes], fnam: str = None, errors: Errors = None,
9294
tree = ASTConverter(options=options,
9395
is_stub=is_stub_file,
9496
errors=errors,
97+
plugin=plugin,
9598
).visit(ast)
9699
assert isinstance(tree, MypyFile)
97100
tree.path = fnam
@@ -135,13 +138,15 @@ class ASTConverter(ast27.NodeTransformer):
135138
def __init__(self,
136139
options: Options,
137140
is_stub: bool,
138-
errors: Errors) -> None:
141+
errors: Errors,
142+
plugin: Plugin) -> None:
139143
self.class_nesting = 0
140144
self.imports = [] # type: List[ImportBase]
141145

142146
self.options = options
143147
self.is_stub = is_stub
144148
self.errors = errors
149+
self.plugin = plugin
145150

146151
def fail(self, msg: str, line: int, column: int) -> None:
147152
self.errors.report(line, column, msg)
@@ -284,8 +289,9 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
284289
args, decompose_stmts = self.transform_args(n.args, n.lineno)
285290

286291
arg_kinds = [arg.kind for arg in args]
287-
arg_names = [arg.variable.name() for arg in args] # type: List[Optional[str]]
288-
arg_names = [None if argument_elide_name(name) else name for name in arg_names]
292+
real_names = [arg.variable.name() for arg in args] # type: List[str]
293+
arg_names = [None if argument_elide_name(name) else name
294+
for name in real_names] # type: List[Optional[str]]
289295
if special_function_elide_names(n.name):
290296
arg_names = [None] * len(arg_names)
291297

@@ -321,6 +327,15 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
321327
arg_types = [a.type_annotation for a in args]
322328
return_type = converter.visit(None)
323329

330+
docstring_hook = self.plugin.get_docstring_parser_hook()
331+
if docstring_hook is not None and not any(arg_types) and return_type is None:
332+
doc = ast27.get_docstring(n, clean=False)
333+
if doc:
334+
types = parse_docstring(docstring_hook, doc.decode('unicode_escape'),
335+
real_names, n.lineno, self.errors)
336+
if types is not None:
337+
arg_types, return_type = types
338+
324339
for arg, arg_type in zip(args, arg_types):
325340
self.set_type_optional(arg_type, arg.initializer)
326341

mypy/parse.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from mypy.errors import Errors
44
from mypy.options import Options
5+
from mypy.plugin import Plugin
56
from mypy.nodes import MypyFile
67

78

89
def parse(source: Union[str, bytes],
910
fnam: str,
1011
errors: Optional[Errors],
11-
options: Options) -> MypyFile:
12+
options: Options,
13+
plugin: Plugin) -> MypyFile:
1214
"""Parse a source file, without doing any semantic analysis.
1315
1416
Return the parse tree. If errors is not provided, raise ParseError
@@ -22,10 +24,12 @@ def parse(source: Union[str, bytes],
2224
return mypy.fastparse.parse(source,
2325
fnam=fnam,
2426
errors=errors,
25-
options=options)
27+
options=options,
28+
plugin=plugin)
2629
else:
2730
import mypy.fastparse2
2831
return mypy.fastparse2.parse(source,
2932
fnam=fnam,
3033
errors=errors,
31-
options=options)
34+
options=options,
35+
plugin=plugin)

mypy/plugin.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar
1+
from typing import Callable, Dict, List, Tuple, Optional, NamedTuple, TypeVar
22

33
from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context
44
from mypy.types import (
55
Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType,
66
AnyType
77
)
88
from mypy.messages import MessageBuilder
9+
from mypy import defaults
910

11+
MYPY = False
12+
if MYPY:
13+
from mypy.errors import Errors
1014

1115
# Create an Instance given full name of class and type arguments.
1216
NamedInstanceCallback = Callable[[str, List[Type]], Type]
@@ -58,6 +62,27 @@
5862
Type # Return type inferred by the callback
5963
]
6064

65+
# A callback for extracting type annotations from docstrings.
66+
#
67+
# Called for each unannotated function that has adocstring.
68+
# The function's return type, if specified, is stored in the mapping with the special
69+
# key 'return'. Other than 'return', each key of the mapping must be one of the
70+
# arguments of the documented function; otherwise, an error will be raised.
71+
DocstringParserHook = Callable[
72+
[
73+
str, # The docstring to be parsed
74+
int, # The line number where the docstring begins
75+
'Errors' # Errors object for reporting errors, warnings, and info
76+
],
77+
Dict[
78+
str, # Argument name, or 'return' for return type.
79+
Tuple[
80+
str, # PEP484-compatible string to use as the argument's type
81+
int # Line number identifying the location of the type annotation
82+
]
83+
]
84+
]
85+
6186

6287
class Plugin:
6388
"""Base class of all type checker plugins.
@@ -69,7 +94,7 @@ class Plugin:
6994
results might be cached).
7095
"""
7196

72-
def __init__(self, python_version: Tuple[int, int]) -> None:
97+
def __init__(self, python_version: Tuple[int, int] = defaults.PYTHON3_VERSION) -> None:
7398
self.python_version = python_version
7499

75100
def get_function_hook(self, fullname: str) -> Optional[FunctionHook]:
@@ -81,6 +106,9 @@ def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHo
81106
def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
82107
return None
83108

109+
def get_docstring_parser_hook(self) -> Optional[DocstringParserHook]:
110+
return None
111+
84112
# TODO: metaclass / class decorator hook
85113

86114

@@ -116,6 +144,9 @@ def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHo
116144
def get_method_hook(self, fullname: str) -> Optional[MethodHook]:
117145
return self._find_hook(lambda plugin: plugin.get_method_hook(fullname))
118146

147+
def get_docstring_parser_hook(self) -> Optional[DocstringParserHook]:
148+
return self._find_hook(lambda plugin: plugin.get_docstring_parser_hook())
149+
119150
def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
120151
for plugin in self._plugins:
121152
hook = lookup(plugin)

mypy/stubgen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module
6464
from mypy.stubutil import is_c_module, write_header
6565
from mypy.options import Options as MypyOptions
66+
from mypy.plugin import Plugin
6667

6768

6869
Options = NamedTuple('Options', [('pyversion', Tuple[int, int]),
@@ -194,8 +195,9 @@ def generate_stub(path: str, output_dir: str, _all_: Optional[List[str]] = None,
194195
source = f.read()
195196
options = MypyOptions()
196197
options.python_version = pyversion
198+
plugin = Plugin(pyversion)
197199
try:
198-
ast = mypy.parse.parse(source, fnam=path, errors=None, options=options)
200+
ast = mypy.parse.parse(source, fnam=path, errors=None, options=options, plugin=plugin)
199201
except mypy.errors.CompileError as e:
200202
# Syntax error!
201203
for m in e.messages:

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
'check-enum.test',
7878
'check-incomplete-fixture.test',
7979
'check-custom-plugin.test',
80+
'check-docstring-hook.test',
8081
]
8182

8283

mypy/test/testparse.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from mypy.parse import parse
1313
from mypy.errors import CompileError
1414
from mypy.options import Options
15+
from mypy.plugin import Plugin
1516

1617

1718
class ParserSuite(Suite):
@@ -43,7 +44,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
4344
n = parse(bytes('\n'.join(testcase.input), 'ascii'),
4445
fnam='main',
4546
errors=None,
46-
options=options)
47+
options=options,
48+
plugin=Plugin(options.python_version))
4749
a = str(n).split('\n')
4850
except CompileError as e:
4951
a = e.messages
@@ -68,7 +70,8 @@ def cases(self) -> List[DataDrivenTestCase]:
6870
def test_parse_error(testcase: DataDrivenTestCase) -> None:
6971
try:
7072
# Compile temporary file. The test file contains non-ASCII characters.
71-
parse(bytes('\n'.join(testcase.input), 'utf-8'), INPUT_FILE_NAME, None, Options())
73+
parse(bytes('\n'.join(testcase.input), 'utf-8'), INPUT_FILE_NAME, None, Options(),
74+
Plugin())
7275
raise AssertionFailure('No errors reported')
7376
except CompileError as e:
7477
# Verify that there was a compile error and that the error messages

0 commit comments

Comments
 (0)