Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support typing_extensions.overload #12602

Merged
merged 2 commits into from
Apr 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType,
OVERLOAD_NAMES,
)
from mypy.sametypes import is_same_type
from mypy.messages import (
Expand Down Expand Up @@ -3981,7 +3982,7 @@ def visit_decorator(self, e: Decorator) -> None:
# may be different from the declared signature.
sig: Type = self.function_type(e.func)
for d in reversed(e.decorators):
if refers_to_fullname(d, 'typing.overload'):
if refers_to_fullname(d, OVERLOAD_NAMES):
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
continue
dec = self.expr_checker.accept(d)
Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
ASSERT_TYPE_NAMES, is_named_instance,
ASSERT_TYPE_NAMES, OVERLOAD_NAMES, is_named_instance,
)
from mypy.typeops import function_type, get_type_vars
from mypy.type_visitor import TypeQuery
Expand Down Expand Up @@ -835,7 +835,7 @@ def analyze_overload_sigs_and_impl(
if isinstance(item, Decorator):
callable = function_type(item.func, self.named_type('builtins.function'))
assert isinstance(callable, CallableType)
if not any(refers_to_fullname(dec, 'typing.overload')
if not any(refers_to_fullname(dec, OVERLOAD_NAMES)
for dec in item.decorators):
if i == len(defn.items) - 1 and not self.is_stub_file:
# Last item outside a stub is impl
Expand Down
23 changes: 14 additions & 9 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from collections import defaultdict

from typing import (
List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast,
List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast,
)
from typing_extensions import Final

Expand Down Expand Up @@ -84,7 +84,7 @@
from mypy.options import Options as MypyOptions
from mypy.types import (
Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance,
AnyType, get_proper_type
AnyType, get_proper_type, OVERLOAD_NAMES
)
from mypy.visitor import NodeVisitor
from mypy.find_sources import create_source_list, InvalidSourceList
Expand All @@ -93,6 +93,10 @@
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
from mypy.moduleinspect import ModuleInspect

TYPING_MODULE_NAMES: Final = (
'typing',
'typing_extensions',
)

# Common ways of naming package containing vendored modules.
VENDOR_PACKAGES: Final = [
Expand Down Expand Up @@ -768,13 +772,15 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup
self.add_decorator('property')
self.add_decorator('abc.abstractmethod')
is_abstract = True
elif self.refers_to_fullname(name, 'typing.overload'):
elif self.refers_to_fullname(name, OVERLOAD_NAMES):
self.add_decorator(name)
self.add_typing_import('overload')
is_overload = True
return is_abstract, is_overload

def refers_to_fullname(self, name: str, fullname: str) -> bool:
def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool:
if isinstance(fullname, tuple):
return any(self.refers_to_fullname(name, fname) for fname in fullname)
module, short = fullname.rsplit('.', 1)
return (self.import_tracker.module_for.get(name) == module and
(name == short or
Expand Down Expand Up @@ -825,8 +831,8 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) ->
expr.expr.name + '.coroutine',
expr.expr.name)
elif (isinstance(expr.expr, NameExpr) and
(expr.expr.name == 'typing' or
self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
(expr.expr.name in TYPING_MODULE_NAMES or
self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and
expr.name == 'overload'):
self.import_tracker.require_name(expr.expr.name)
self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
Expand Down Expand Up @@ -1060,7 +1066,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
and name not in self.referenced_names
and (not self._all_ or name in IGNORED_DUNDERS)
and not is_private
and module not in ('abc', 'typing', 'asyncio')):
and module not in ('abc', 'asyncio') + TYPING_MODULE_NAMES):
# An imported name that is never referenced in the module is assumed to be
# exported, unless there is an explicit __all__. Note that we need to special
# case 'abc' since some references are deleted during semantic analysis.
Expand Down Expand Up @@ -1118,8 +1124,7 @@ def get_init(self, lvalue: str, rvalue: Expression,
typename = self.print_annotation(annotation)
if (isinstance(annotation, UnboundType) and not annotation.args and
annotation.name == 'Final' and
self.import_tracker.module_for.get('Final') in ('typing',
'typing_extensions')):
self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES):
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += '[{}]'.format(final_arg)
Expand Down
3 changes: 1 addition & 2 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,9 +912,8 @@ def apply_decorator_to_funcitem(
return None
if decorator.fullname in (
"builtins.staticmethod",
"typing.overload",
"abc.abstractmethod",
):
) or decorator.fullname in mypy.types.OVERLOAD_NAMES:
return func
if decorator.fullname == "builtins.classmethod":
assert func.arguments[0].variable.name in ("cls", "metacls")
Expand Down
5 changes: 5 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@
'typing_extensions.assert_type',
)

OVERLOAD_NAMES: Final = (
'typing.overload',
'typing_extensions.overload',
)

# Attributes that can optionally be defined in the body of a subclass of
# enum.Enum but are removed from the class __dict__ by EnumMeta.
ENUM_REMOVED_PROPS: Final = (
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ class A: pass
class B: pass
[builtins fixtures/isinstance.pyi]

[case testTypingExtensionsOverload]
from typing import Any
from typing_extensions import overload
@overload
def f(x: 'A') -> 'B': ...
@overload
def f(x: 'B') -> 'A': ...

def f(x: Any) -> Any:
pass

reveal_type(f(A())) # N: Revealed type is "__main__.B"
reveal_type(f(B())) # N: Revealed type is "__main__.A"

class A: pass
class B: pass
[builtins fixtures/isinstance.pyi]

[case testOverloadNeedsImplementation]
from typing import overload, Any
@overload # E: An overloaded function outside a stub file must have an implementation
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/typing_extensions.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TypeVar, Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
from typing import TYPE_CHECKING as TYPE_CHECKING
from typing import NewType as NewType
from typing import NewType as NewType, overload as overload

import sys

Expand Down
79 changes: 78 additions & 1 deletion test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2461,13 +2461,58 @@ class A:
def f(self, x: Tuple[int, int]) -> int: ...


@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: Tuple[int, int]) -> int: ...

[case testOverload_fromTypingExtensionsImport]
from typing import Tuple, Union
from typing_extensions import overload

class A:
@overload
def f(self, x: int, y: int) -> int:
...

@overload
def f(self, x: Tuple[int, int]) -> int:
...

def f(self, *args: Union[int, Tuple[int, int]]) -> int:
pass

@overload
def f(x: int, y: int) -> int:
...

@overload
def f(x: Tuple[int, int]) -> int:
...

def f(*args: Union[int, Tuple[int, int]]) -> int:
pass


[out]
from typing import Tuple
from typing_extensions import overload

class A:
@overload
def f(self, x: int, y: int) -> int: ...
@overload
def f(self, x: Tuple[int, int]) -> int: ...


@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: Tuple[int, int]) -> int: ...

[case testOverload_importTyping]
import typing
import typing_extensions

class A:
@typing.overload
Expand Down Expand Up @@ -2506,9 +2551,21 @@ def f(x: typing.Tuple[int, int]) -> int:
def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass

@typing_extensions.overload
def g(x: int, y: int) -> int:
...

@typing_extensions.overload
def g(x: typing.Tuple[int, int]) -> int:
...

def g(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass


[out]
import typing
import typing_extensions

class A:
@typing.overload
Expand All @@ -2527,10 +2584,14 @@ class A:
def f(x: int, y: int) -> int: ...
@typing.overload
def f(x: typing.Tuple[int, int]) -> int: ...

@typing_extensions.overload
def g(x: int, y: int) -> int: ...
@typing_extensions.overload
def g(x: typing.Tuple[int, int]) -> int: ...

[case testOverload_importTypingAs]
import typing as t
import typing_extensions as te

class A:
@t.overload
Expand Down Expand Up @@ -2570,8 +2631,20 @@ def f(*args: t.Union[int, t.Tuple[int, int]]) -> int:
pass


@te.overload
def g(x: int, y: int) -> int:
...

@te.overload
def g(x: t.Tuple[int, int]) -> int:
...

def g(*args: t.Union[int, t.Tuple[int, int]]) -> int:
pass

[out]
import typing as t
import typing_extensions as te

class A:
@t.overload
Expand All @@ -2590,6 +2663,10 @@ class A:
def f(x: int, y: int) -> int: ...
@t.overload
def f(x: t.Tuple[int, int]) -> int: ...
@te.overload
def g(x: int, y: int) -> int: ...
@te.overload
def g(x: t.Tuple[int, int]) -> int: ...

[case testProtocol_semanal]
from typing import Protocol, TypeVar
Expand Down