Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions mypy/stubdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,18 @@ def add_token(self, token: tokenize.TokenInfo) -> None:
self.reset()
return
self.keyword_only = len(self.args)
# Add * as an argument
self.args.append(ArgSig(name="*"))
self.accumulator = ""
else:
if self.accumulator.startswith("*"):
self.keyword_only = len(self.args) + 1
self.arg_name = self.accumulator
if not (
token.string == ")" and self.accumulator.strip() == ""
) and not _ARG_NAME_RE.match(self.arg_name):
if (
not (token.string == ")" and self.accumulator.strip() == "")
and not _ARG_NAME_RE.match(self.arg_name)
and self.arg_name not in ("/", "*")
):
# Invalid argument name.
self.reset()
return
Expand All @@ -281,7 +285,7 @@ def add_token(self, token: tokenize.TokenInfo) -> None:
if (
self.state[-1] == STATE_ARGUMENT_LIST
and self.keyword_only is not None
and self.keyword_only == len(self.args)
and self.keyword_only == len(self.args) - 1
and not self.arg_name
):
# Error condition: * must be followed by arguments
Expand Down Expand Up @@ -320,8 +324,8 @@ def add_token(self, token: tokenize.TokenInfo) -> None:
self.reset()
return
self.pos_only = len(self.args)
self.state.append(STATE_ARGUMENT_TYPE)
self.accumulator = ""
# Set accumulator to / so it gets processed like a regular argument
self.accumulator = "/"

elif token.type == tokenize.OP and token.string == "->" and self.state[-1] == STATE_INIT:
self.accumulator = ""
Expand All @@ -347,6 +351,8 @@ def add_token(self, token: tokenize.TokenInfo) -> None:
self.found = False
self.args = []
self.ret_type = "Any"
self.pos_only = None
self.keyword_only = None
# Leave state as INIT.
else:
self.accumulator += token.string
Expand All @@ -356,6 +362,8 @@ def reset(self) -> None:
self.args = []
self.found = False
self.accumulator = ""
self.pos_only = None
self.keyword_only = None

def get_signatures(self) -> list[FunctionSig]:
"""Return sorted copy of the list of signatures found so far."""
Expand Down
140 changes: 126 additions & 14 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,40 @@ def test_infer_sig_from_docstring(self) -> None:
],
)

def test_infer_sig_from_docstring_overloads(self) -> None:

assert_equal(
infer_sig_from_docstring("\nfunc(x: int=3) -> int\nfunc(x: str) -> str", "func"),
[
FunctionSig(
name="func", args=[ArgSig(name="x", type="int", default=True)], ret_type="int"
),
FunctionSig(
name="func", args=[ArgSig(name="x", type="str", default=False)], ret_type="str"
),
],
)

assert_equal(
infer_sig_from_docstring("func(x: foo.bar)\nfunc(x: str) -> foo.bar", "func"),
[
FunctionSig(name="func", args=[ArgSig(name="x", type="foo.bar")], ret_type="Any"),
FunctionSig(name="func", args=[ArgSig(name="x", type="str")], ret_type="foo.bar"),
],
)

assert_equal(
infer_sig_from_docstring(
"\nfunc(x: int=3) -> int\nfunc(x: invalid::type<with_template>)", "func"
),
[
FunctionSig(
name="func", args=[ArgSig(name="x", type="int", default=True)], ret_type="int"
),
FunctionSig(name="func", args=[ArgSig(name="x", type=None)], ret_type="Any"),
],
)

def test_infer_sig_from_docstring_duplicate_args(self) -> None:
assert_equal(
infer_sig_from_docstring("\nfunc(x, x) -> str\nfunc(x, y) -> int", "func"),
Expand Down Expand Up @@ -435,30 +469,36 @@ def test_infer_sig_from_docstring_args_kwargs_errors(self) -> None:
assert_equal(infer_sig_from_docstring("func(**kwargs, *args) -> int", "func"), [])

def test_infer_sig_from_docstring_positional_only_arguments(self) -> None:
assert_equal(
infer_sig_from_docstring("func(self, /) -> str", "func"),
[FunctionSig(name="func", args=[ArgSig(name="self")], ret_type="str")],
)

assert_equal(
infer_sig_from_docstring("func(self, x, /) -> str", "func"),
[
FunctionSig(
name="func", args=[ArgSig(name="self"), ArgSig(name="x")], ret_type="str"
name="func",
args=[ArgSig(name="self"), ArgSig(name="x"), ArgSig(name="/")],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring("func(x, /, y) -> int", "func"),
[FunctionSig(name="func", args=[ArgSig(name="x"), ArgSig(name="y")], ret_type="int")],
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="/"), ArgSig(name="y")],
ret_type="int",
)
],
)

assert_equal(
infer_sig_from_docstring("func(x, /, *args) -> str", "func"),
[
FunctionSig(
name="func", args=[ArgSig(name="x"), ArgSig(name="*args")], ret_type="str"
name="func",
args=[ArgSig(name="x"), ArgSig(name="/"), ArgSig(name="*args")],
ret_type="str",
)
],
)
Expand All @@ -468,51 +508,121 @@ def test_infer_sig_from_docstring_positional_only_arguments(self) -> None:
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="kwonly"), ArgSig(name="**kwargs")],
args=[
ArgSig(name="x"),
ArgSig(name="/"),
ArgSig(name="*"),
ArgSig(name="kwonly"),
ArgSig(name="**kwargs"),
],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring("func(self, /) -> str\nfunc(self, x, /) -> str", "func"),
[
FunctionSig(
name="func", args=[ArgSig(name="self"), ArgSig(name="/")], ret_type="str"
),
FunctionSig(
name="func",
args=[ArgSig(name="self"), ArgSig(name="x"), ArgSig(name="/")],
ret_type="str",
),
],
)

def test_infer_sig_from_docstring_keyword_only_arguments(self) -> None:
assert_equal(
infer_sig_from_docstring("func(*, x) -> str", "func"),
[FunctionSig(name="func", args=[ArgSig(name="x")], ret_type="str")],
[FunctionSig(name="func", args=[ArgSig(name="*"), ArgSig(name="x")], ret_type="str")],
)

assert_equal(
infer_sig_from_docstring("func(x, *, y) -> str", "func"),
[FunctionSig(name="func", args=[ArgSig(name="x"), ArgSig(name="y")], ret_type="str")],
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="*"), ArgSig(name="y")],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring("func(*, x, y) -> str", "func"),
[FunctionSig(name="func", args=[ArgSig(name="x"), ArgSig(name="y")], ret_type="str")],
[
FunctionSig(
name="func",
args=[ArgSig(name="*"), ArgSig(name="x"), ArgSig(name="y")],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring("func(x, *, kwonly, **kwargs) -> str", "func"),
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="kwonly"), ArgSig("**kwargs")],
args=[
ArgSig(name="x"),
ArgSig(name="*"),
ArgSig(name="kwonly"),
ArgSig("**kwargs"),
],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring(
"func(*, x) -> str\nfunc(x, *, kwonly, **kwargs) -> str", "func"
),
[
FunctionSig(
name="func", args=[ArgSig(name="*"), ArgSig(name="x")], ret_type="str"
),
FunctionSig(
name="func",
args=[
ArgSig(name="x"),
ArgSig(name="*"),
ArgSig(name="kwonly"),
ArgSig("**kwargs"),
],
ret_type="str",
),
],
)

def test_infer_sig_from_docstring_pos_only_and_keyword_only_arguments(self) -> None:
assert_equal(
infer_sig_from_docstring("func(x, /, *, y) -> str", "func"),
[FunctionSig(name="func", args=[ArgSig(name="x"), ArgSig(name="y")], ret_type="str")],
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="/"), ArgSig(name="*"), ArgSig(name="y")],
ret_type="str",
)
],
)

assert_equal(
infer_sig_from_docstring("func(x, /, y, *, z) -> str", "func"),
[
FunctionSig(
name="func",
args=[ArgSig(name="x"), ArgSig(name="y"), ArgSig(name="z")],
args=[
ArgSig(name="x"),
ArgSig(name="/"),
ArgSig(name="y"),
ArgSig(name="*"),
ArgSig(name="z"),
],
ret_type="str",
)
],
Expand All @@ -525,7 +635,9 @@ def test_infer_sig_from_docstring_pos_only_and_keyword_only_arguments(self) -> N
name="func",
args=[
ArgSig(name="x"),
ArgSig(name="/"),
ArgSig(name="y"),
ArgSig(name="*"),
ArgSig(name="z"),
ArgSig("**kwargs"),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ class Point:
degree: ClassVar[Point.AngleUnit] = ...
radian: ClassVar[Point.AngleUnit] = ...
def __init__(self, value: typing.SupportsInt) -> None: ...
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __index__(self) -> int: ...
def __int__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __eq__(self, other: object, /) -> bool: ...
def __hash__(self, /) -> int: ...
def __index__(self, /) -> int: ...
def __int__(self, /) -> int: ...
def __ne__(self, other: object, /) -> bool: ...
@property
def name(self) -> str: ...
@property
Expand All @@ -28,11 +28,11 @@ class Point:
mm: ClassVar[Point.LengthUnit] = ...
pixel: ClassVar[Point.LengthUnit] = ...
def __init__(self, value: typing.SupportsInt) -> None: ...
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __index__(self) -> int: ...
def __int__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __eq__(self, other: object, /) -> bool: ...
def __hash__(self, /) -> int: ...
def __index__(self, /) -> int: ...
def __int__(self, /) -> int: ...
def __ne__(self, other: object, /) -> bool: ...
@property
def name(self) -> str: ...
@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ class Point:
radian: ClassVar[Point.AngleUnit] = ...
def __init__(self, value: typing.SupportsInt) -> None:
"""__init__(self: pybind11_fixtures.demo.Point.AngleUnit, value: typing.SupportsInt) -> None"""
def __eq__(self, other: object) -> bool:
def __eq__(self, other: object, /) -> bool:
"""__eq__(self: object, other: object, /) -> bool"""
def __hash__(self) -> int:
def __hash__(self, /) -> int:
"""__hash__(self: object, /) -> int"""
def __index__(self) -> int:
def __index__(self, /) -> int:
"""__index__(self: pybind11_fixtures.demo.Point.AngleUnit, /) -> int"""
def __int__(self) -> int:
def __int__(self, /) -> int:
"""__int__(self: pybind11_fixtures.demo.Point.AngleUnit, /) -> int"""
def __ne__(self, other: object) -> bool:
def __ne__(self, other: object, /) -> bool:
"""__ne__(self: object, other: object, /) -> bool"""
@property
def name(self) -> str:
Expand All @@ -52,15 +52,15 @@ class Point:
pixel: ClassVar[Point.LengthUnit] = ...
def __init__(self, value: typing.SupportsInt) -> None:
"""__init__(self: pybind11_fixtures.demo.Point.LengthUnit, value: typing.SupportsInt) -> None"""
def __eq__(self, other: object) -> bool:
def __eq__(self, other: object, /) -> bool:
"""__eq__(self: object, other: object, /) -> bool"""
def __hash__(self) -> int:
def __hash__(self, /) -> int:
"""__hash__(self: object, /) -> int"""
def __index__(self) -> int:
def __index__(self, /) -> int:
"""__index__(self: pybind11_fixtures.demo.Point.LengthUnit, /) -> int"""
def __int__(self) -> int:
def __int__(self, /) -> int:
"""__int__(self: pybind11_fixtures.demo.Point.LengthUnit, /) -> int"""
def __ne__(self, other: object) -> bool:
def __ne__(self, other: object, /) -> bool:
"""__ne__(self: object, other: object, /) -> bool"""
@property
def name(self) -> str:
Expand Down