Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inference for overloaded __call__ with generic self #16053

Merged
merged 15 commits into from
Sep 19, 2023
Merged
4 changes: 3 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,7 @@ def check_call(
callable_node: Expression | None = None,
callable_name: str | None = None,
object_type: Type | None = None,
original_type: Type | None = None,
) -> tuple[Type, Type]:
"""Type check a call.

Expand Down Expand Up @@ -1538,7 +1539,7 @@ def check_call(
is_super=False,
is_operator=True,
msg=self.msg,
original_type=callee,
original_type=original_type or callee,
chk=self.chk,
in_literal_context=self.is_literal_context(),
)
Expand Down Expand Up @@ -1579,6 +1580,7 @@ def check_call(
callable_node,
callable_name,
object_type,
original_type=callee,
)
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
Expand Down
13 changes: 6 additions & 7 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,12 @@ def analyze_instance_member_access(
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class)
# TODO: should we skip these steps for static methods as well?
# Since generic static methods should not be allowed.
Expand Down
51 changes: 29 additions & 22 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,19 +454,22 @@ def visit_instance(self, left: Instance) -> bool:
if isinstance(unpacked, Instance):
return self._is_subtype(left, unpacked)
if left.type.has_base(right.partial_fallback.type.fullname):
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
if len(mapped.args) == 1 and isinstance(mapped.args[0], UnpackType):
unpacked = get_proper_type(mapped.args[0].type)
if isinstance(unpacked, Instance):
assert unpacked.type.fullname == "builtins.tuple"
if isinstance(get_proper_type(unpacked.args[0]), AnyType):
return not self.proper_subtype
if mapped.type.fullname == "builtins.tuple" and isinstance(
get_proper_type(mapped.args[0]), AnyType
):
return not self.proper_subtype
if not self.proper_subtype:
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
for arg in map(get_proper_type, mapped.args):
if isinstance(arg, UnpackType):
unpacked = get_proper_type(arg.type)
if not isinstance(unpacked, Instance):
break
assert unpacked.type.fullname == "builtins.tuple"
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
break
elif not isinstance(arg, AnyType):
break
else:
return True
return False
if isinstance(right, TypeVarTupleType):
# tuple[Any, ...] is like Any in the world of tuples (see special case above).
Expand Down Expand Up @@ -534,15 +537,19 @@ def visit_instance(self, left: Instance) -> bool:
right_args = (
right_prefix + (TupleType(list(right_middle), fallback),) + right_suffix
)
if len(t.args) == 1 and isinstance(t.args[0], UnpackType):
unpacked = get_proper_type(t.args[0].type)
if isinstance(unpacked, Instance):
assert unpacked.type.fullname == "builtins.tuple"
if (
isinstance(get_proper_type(unpacked.args[0]), AnyType)
and not self.proper_subtype
):
return True
if not self.proper_subtype:
for arg in map(get_proper_type, t.args):
if isinstance(arg, UnpackType):
unpacked = get_proper_type(arg.type)
if not isinstance(unpacked, Instance):
break
assert unpacked.type.fullname == "builtins.tuple"
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
break
elif not isinstance(arg, AnyType):
break
else:
return True
type_params = zip(left_args, right_args, right.type.defn.type_vars)
else:
type_params = zip(t.args, right.args, right.type.defn.type_vars)
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6650,3 +6650,27 @@ def d(x: int) -> int: ...
def d(f: int, *, x: int) -> str: ...
def d(*args, **kwargs): ...
[builtins fixtures/tuple.pyi]

[case testOverloadCallableGenericSelf]
from typing import Any, TypeVar, Generic, overload, reveal_type

T = TypeVar("T")

class MyCallable(Generic[T]):
def __init__(self, t: T):
self.t = t

@overload
def __call__(self: "MyCallable[int]") -> str: ...
@overload
def __call__(self: "MyCallable[str]") -> int: ...
def __call__(self): ...

c = MyCallable(5)
reveal_type(c) # N: Revealed type is "__main__.MyCallable[builtins.int]"
reveal_type(c()) # N: Revealed type is "builtins.str"

c2 = MyCallable("test")
reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]"
reveal_type(c2()) # should be int # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
14 changes: 14 additions & 0 deletions test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,21 @@ def foo(o: CallableTuple) -> int:
class CallableTuple(Tuple[str, int]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTypeTupleGenericCall]
from typing import Generic, Tuple, TypeVar

T = TypeVar('T')

def foo(o: CallableTuple[int]) -> int:
reveal_type(o) # N: Revealed type is "Tuple[builtins.str, builtins.int, fallback=__main__.CallableTuple[builtins.int]]"
reveal_type(o.count(3)) # N: Revealed type is "builtins.int"
return o(1, 2)

class CallableTuple(Tuple[str, T]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTupleCompatibleWithSequence]
Expand Down