Skip to content

Commit

Permalink
Support typing_extensions.overload (#12602)
Browse files Browse the repository at this point in the history
This always existed in typing_extensions, but was an alias for
typing.overload. With python/typing#1140, it will actually make
a difference at runtime which one you use.

Note that this shouldn't change mypy's behaviour, since
we alias typing_extensions.overload to typing.overload
in typeshed, but this makes the logic less fragile.
  • Loading branch information
JelleZijlstra authored Apr 16, 2022
1 parent 10ba5c1 commit 0df8cf5
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 17 deletions.
5 changes: 3 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
is_named_instance, union_items, TypeQuery, LiteralType,
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType,
OVERLOAD_NAMES,
)
from mypy.sametypes import is_same_type
from mypy.messages import (
Expand Down Expand Up @@ -3981,7 +3982,7 @@ def visit_decorator(self, e: Decorator) -> None:
# may be different from the declared signature.
sig: Type = self.function_type(e.func)
for d in reversed(e.decorators):
if refers_to_fullname(d, 'typing.overload'):
if refers_to_fullname(d, OVERLOAD_NAMES):
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
continue
dec = self.expr_checker.accept(d)
Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
ASSERT_TYPE_NAMES, is_named_instance,
ASSERT_TYPE_NAMES, OVERLOAD_NAMES, is_named_instance,
)
from mypy.typeops import function_type, get_type_vars
from mypy.type_visitor import TypeQuery
Expand Down Expand Up @@ -835,7 +835,7 @@ def analyze_overload_sigs_and_impl(
if isinstance(item, Decorator):
callable = function_type(item.func, self.named_type('builtins.function'))
assert isinstance(callable, CallableType)
if not any(refers_to_fullname(dec, 'typing.overload')
if not any(refers_to_fullname(dec, OVERLOAD_NAMES)
for dec in item.decorators):
if i == len(defn.items) - 1 and not self.is_stub_file:
# Last item outside a stub is impl
Expand Down
23 changes: 14 additions & 9 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from collections import defaultdict

from typing import (
List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast,
List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast,
)
from typing_extensions import Final

Expand Down Expand Up @@ -84,7 +84,7 @@
from mypy.options import Options as MypyOptions
from mypy.types import (
Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance,
AnyType, get_proper_type
AnyType, get_proper_type, OVERLOAD_NAMES
)
from mypy.visitor import NodeVisitor
from mypy.find_sources import create_source_list, InvalidSourceList
Expand All @@ -93,6 +93,10 @@
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
from mypy.moduleinspect import ModuleInspect

TYPING_MODULE_NAMES: Final = (
'typing',
'typing_extensions',
)

# Common ways of naming package containing vendored modules.
VENDOR_PACKAGES: Final = [
Expand Down Expand Up @@ -768,13 +772,15 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup
self.add_decorator('property')
self.add_decorator('abc.abstractmethod')
is_abstract = True
elif self.refers_to_fullname(name, 'typing.overload'):
elif self.refers_to_fullname(name, OVERLOAD_NAMES):
self.add_decorator(name)
self.add_typing_import('overload')
is_overload = True
return is_abstract, is_overload

def refers_to_fullname(self, name: str, fullname: str) -> bool:
def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool:
if isinstance(fullname, tuple):
return any(self.refers_to_fullname(name, fname) for fname in fullname)
module, short = fullname.rsplit('.', 1)
return (self.import_tracker.module_for.get(name) == module and
(name == short or
Expand Down Expand Up @@ -825,8 +831,8 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) ->
expr.expr.name + '.coroutine',
expr.expr.name)
elif (isinstance(expr.expr, NameExpr) and
(expr.expr.name == 'typing' or
self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
(expr.expr.name in TYPING_MODULE_NAMES or
self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and
expr.name == 'overload'):
self.import_tracker.require_name(expr.expr.name)
self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
Expand Down Expand Up @@ -1060,7 +1066,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
and name not in self.referenced_names
and (not self._all_ or name in IGNORED_DUNDERS)
and not is_private
and module not in ('abc', 'typing', 'asyncio')):
and module not in ('abc', 'asyncio') + TYPING_MODULE_NAMES):
# An imported name that is never referenced in the module is assumed to be
# exported, unless there is an explicit __all__. Note that we need to special
# case 'abc' since some references are deleted during semantic analysis.
Expand Down Expand Up @@ -1118,8 +1124,7 @@ def get_init(self, lvalue: str, rvalue: Expression,
typename = self.print_annotation(annotation)
if (isinstance(annotation, UnboundType) and not annotation.args and
annotation.name == 'Final' and
self.import_tracker.module_for.get('Final') in ('typing',
'typing_extensions')):
self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES):
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += '[{}]'.format(final_arg)
Expand Down
3 changes: 1 addition & 2 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,9 +912,8 @@ def apply_decorator_to_funcitem(
return None
if decorator.fullname in (
"builtins.staticmethod",
"typing.overload",
"abc.abstractmethod",
):
) or decorator.fullname in mypy.types.OVERLOAD_NAMES:
return func
if decorator.fullname == "builtins.classmethod":
assert func.arguments[0].variable.name in ("cls", "metacls")
Expand Down
5 changes: 5 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@
'typing_extensions.assert_type',
)

OVERLOAD_NAMES: Final = (
'typing.overload',
'typing_extensions.overload',
)

# Attributes that can optionally be defined in the body of a subclass of
# enum.Enum but are removed from the class __dict__ by EnumMeta.
ENUM_REMOVED_PROPS: Final = (
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ class A: pass
class B: pass
[builtins fixtures/isinstance.pyi]

[case testTypingExtensionsOverload]
from typing import Any
from typing_extensions import overload
@overload
def f(x: 'A') -> 'B': ...
@overload
def f(x: 'B') -> 'A': ...

def f(x: Any) -> Any:
pass

reveal_type(f(A())) # N: Revealed type is "__main__.B"
reveal_type(f(B())) # N: Revealed type is "__main__.A"

class A: pass
class B: pass
[builtins fixtures/isinstance.pyi]

[case testOverloadNeedsImplementation]
from typing import overload, Any
@overload # E: An overloaded function outside a stub file must have an implementation
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/lib-stub/typing_extensions.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TypeVar, Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
from typing import TYPE_CHECKING as TYPE_CHECKING
from typing import NewType as NewType
from typing import NewType as NewType, overload as overload

import sys

Expand Down
79 changes: 78 additions & 1 deletion test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2461,13 +2461,58 @@ class A:
def f(self, x: Tuple[int, int]) -> int: ...


@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: Tuple[int, int]) -> int: ...

[case testOverload_fromTypingExtensionsImport]
from typing import Tuple, Union
from typing_extensions import overload

class A:
@overload
def f(self, x: int, y: int) -> int:
...

@overload
def f(self, x: Tuple[int, int]) -> int:
...

def f(self, *args: Union[int, Tuple[int, int]]) -> int:
pass

@overload
def f(x: int, y: int) -> int:
...

@overload
def f(x: Tuple[int, int]) -> int:
...

def f(*args: Union[int, Tuple[int, int]]) -> int:
pass


[out]
from typing import Tuple
from typing_extensions import overload

class A:
@overload
def f(self, x: int, y: int) -> int: ...
@overload
def f(self, x: Tuple[int, int]) -> int: ...


@overload
def f(x: int, y: int) -> int: ...
@overload
def f(x: Tuple[int, int]) -> int: ...

[case testOverload_importTyping]
import typing
import typing_extensions

class A:
@typing.overload
Expand Down Expand Up @@ -2506,9 +2551,21 @@ def f(x: typing.Tuple[int, int]) -> int:
def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass

@typing_extensions.overload
def g(x: int, y: int) -> int:
...

@typing_extensions.overload
def g(x: typing.Tuple[int, int]) -> int:
...

def g(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
pass


[out]
import typing
import typing_extensions

class A:
@typing.overload
Expand All @@ -2527,10 +2584,14 @@ class A:
def f(x: int, y: int) -> int: ...
@typing.overload
def f(x: typing.Tuple[int, int]) -> int: ...

@typing_extensions.overload
def g(x: int, y: int) -> int: ...
@typing_extensions.overload
def g(x: typing.Tuple[int, int]) -> int: ...

[case testOverload_importTypingAs]
import typing as t
import typing_extensions as te

class A:
@t.overload
Expand Down Expand Up @@ -2570,8 +2631,20 @@ def f(*args: t.Union[int, t.Tuple[int, int]]) -> int:
pass


@te.overload
def g(x: int, y: int) -> int:
...

@te.overload
def g(x: t.Tuple[int, int]) -> int:
...

def g(*args: t.Union[int, t.Tuple[int, int]]) -> int:
pass

[out]
import typing as t
import typing_extensions as te

class A:
@t.overload
Expand All @@ -2590,6 +2663,10 @@ class A:
def f(x: int, y: int) -> int: ...
@t.overload
def f(x: t.Tuple[int, int]) -> int: ...
@te.overload
def g(x: int, y: int) -> int: ...
@te.overload
def g(x: t.Tuple[int, int]) -> int: ...

[case testProtocol_semanal]
from typing import Protocol, TypeVar
Expand Down

0 comments on commit 0df8cf5

Please sign in to comment.