Skip to content

Commit 8642f3d

Browse files
committed
Fix inference bug
1 parent 86993a0 commit 8642f3d

File tree

10 files changed

+77
-35
lines changed

10 files changed

+77
-35
lines changed

mypy/checkexpr.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,26 @@ def infer_function_type_arguments(
17761776
callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
17771777
)
17781778

1779+
return_type = get_proper_type(callee_type.ret_type)
1780+
if isinstance(return_type, CallableType):
1781+
# fixup:
1782+
# def [T] () -> def (T) -> T
1783+
# into
1784+
# def () -> def [T] (T) -> T
1785+
for i, argument in enumerate(inferred_args):
1786+
if isinstance(get_proper_type(argument), UninhabitedType):
1787+
inferred_args[i] = callee_type.variables[i]
1788+
1789+
# handle multiple type variables
1790+
return_type = return_type.copy_modified(
1791+
variables=[*return_type.variables, callee_type.variables[i]]
1792+
)
1793+
1794+
callee_type = callee_type.copy_modified(
1795+
# am I allowed to assign the get_proper_type'd thing?
1796+
ret_type=return_type
1797+
)
1798+
17791799
if (
17801800
callee_type.special_sig == "dict"
17811801
and len(inferred_args) == 2

mypy/constraints.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Iterable, List, Sequence
5+
from typing import TYPE_CHECKING, Iterable, List, Sequence, Union
66
from typing_extensions import Final
77

88
import mypy.subtypes
@@ -713,26 +713,37 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
713713
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
714714
suffix = suffix.copy_modified(from_concatenate=from_concat)
715715

716-
717716
prefix = mapped_arg.prefix
718717
length = len(prefix.arg_types)
719718
if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
720719
# no such thing as variance for ParamSpecs
721720
# TODO: is there a case I am missing?
722-
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix.copy_modified(
723-
arg_types=suffix.arg_types[length:],
724-
arg_kinds=suffix.arg_kinds[length:],
725-
arg_names=suffix.arg_names[length:],
726-
)))
721+
res.append(
722+
Constraint(
723+
mapped_arg,
724+
SUPERTYPE_OF,
725+
suffix.copy_modified(
726+
arg_types=suffix.arg_types[length:],
727+
arg_kinds=suffix.arg_kinds[length:],
728+
arg_names=suffix.arg_names[length:],
729+
),
730+
)
731+
)
727732
elif isinstance(suffix, ParamSpecType):
728733
suffix_prefix = suffix.prefix
729-
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix.copy_modified(
730-
prefix=suffix_prefix.copy_modified(
731-
arg_types=suffix_prefix.arg_types[length:],
732-
arg_kinds=suffix_prefix.arg_kinds[length:],
733-
arg_names=suffix_prefix.arg_names[length:]
734+
res.append(
735+
Constraint(
736+
mapped_arg,
737+
SUPERTYPE_OF,
738+
suffix.copy_modified(
739+
prefix=suffix_prefix.copy_modified(
740+
arg_types=suffix_prefix.arg_types[length:],
741+
arg_kinds=suffix_prefix.arg_kinds[length:],
742+
arg_names=suffix_prefix.arg_names[length:],
743+
)
744+
),
734745
)
735-
)))
746+
)
736747
else:
737748
# This case should have been handled above.
738749
assert not isinstance(tvar, TypeVarTupleType)
@@ -947,12 +958,15 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
947958
prefix_len = len(prefix.arg_types)
948959
cactual_ps = cactual.param_spec()
949960

961+
cactual_prefix: Union[Parameters, CallableType]
950962
if cactual_ps:
951963
cactual_prefix = cactual_ps.prefix
952964
else:
953965
cactual_prefix = cactual
954966

955-
max_prefix_len = len([k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)])
967+
max_prefix_len = len(
968+
[k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)]
969+
)
956970
prefix_len = min(prefix_len, max_prefix_len)
957971

958972
# we could check the prefixes match here, but that should be caught elsewhere.
@@ -970,13 +984,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
970984
)
971985
)
972986
else:
973-
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps.copy_modified(
974-
prefix=cactual_prefix.copy_modified(
975-
arg_types=cactual_prefix.arg_types[prefix_len:],
976-
arg_kinds=cactual_prefix.arg_kinds[prefix_len:],
977-
arg_names=cactual_prefix.arg_names[prefix_len:]
987+
# guaranteed due to if conditions
988+
assert isinstance(cactual_prefix, Parameters)
989+
990+
res.append(
991+
Constraint(
992+
param_spec,
993+
SUBTYPE_OF,
994+
cactual_ps.copy_modified(
995+
prefix=cactual_prefix.copy_modified(
996+
arg_types=cactual_prefix.arg_types[prefix_len:],
997+
arg_kinds=cactual_prefix.arg_kinds[prefix_len:],
998+
arg_names=cactual_prefix.arg_names[prefix_len:],
999+
)
1000+
),
9781001
)
979-
)))
1002+
)
9801003

9811004
# compare prefixes
9821005
cactual_prefix = cactual.copy_modified(

mypy/erasetype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
175175
return t.prefix.copy_modified(
176176
arg_types=t.prefix.arg_types + [self.replacement, self.replacement],
177177
arg_kinds=t.prefix.arg_kinds + [ARG_STAR, ARG_STAR2],
178-
arg_names=t.prefix.arg_names + [None, None]
178+
arg_names=t.prefix.arg_names + [None, None],
179179
)
180180
return t
181181

mypy/expandtype.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def freshen_function_type_vars(callee: F) -> F:
126126
if isinstance(v, TypeVarType):
127127
tv: TypeVarLikeType = TypeVarType.new_unification_variable(v)
128128
elif isinstance(v, TypeVarTupleType):
129-
assert isinstance(v, TypeVarTupleType)
130129
tv = TypeVarTupleType.new_unification_variable(v)
131130
else:
132131
assert isinstance(v, ParamSpecType)

mypy/nodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2515,7 +2515,6 @@ def __init__(
25152515
self, name: str, fullname: str, upper_bound: mypy.types.Type, variance: int = INVARIANT
25162516
) -> None:
25172517
super().__init__(name, fullname, upper_bound, variance)
2518-
assert isinstance(upper_bound, (mypy.types.CallableType, mypy.types.Parameters))
25192518

25202519
def accept(self, visitor: ExpressionVisitor[T]) -> T:
25212520
return visitor.visit_paramspec_expr(self)

mypy/semanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5606,7 +5606,7 @@ def top_caller(self) -> Parameters:
56065606
return Parameters(
56075607
arg_types=[self.object_type(), self.object_type()],
56085608
arg_kinds=[ARG_STAR, ARG_STAR2],
5609-
arg_names=[None, None]
5609+
arg_names=[None, None],
56105610
)
56115611

56125612
def str_type(self) -> Instance:

mypy/strconv.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,9 @@ def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> str:
484484
a += ["Variance(COVARIANT)"]
485485
if o.variance == mypy.nodes.CONTRAVARIANT:
486486
a += ["Variance(CONTRAVARIANT)"]
487-
if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
488-
a += [f"UpperBound({o.upper_bound})"]
487+
# ParamSpecs do not have upper bounds!!! (should this be left for future proofing?)
488+
# if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"):
489+
# a += [f"UpperBound({o.upper_bound})"]
489490
return self.dump(a, o)
490491

491492
def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> str:

mypy/types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,6 @@ def __init__(
675675
super().__init__(name, fullname, id, upper_bound, line=line, column=column)
676676
self.flavor = flavor
677677
self.prefix = prefix or Parameters([], [], [])
678-
assert flavor != ParamSpecFlavor.BARE or isinstance(upper_bound, (CallableType, Parameters))
679678

680679
@staticmethod
681680
def new_unification_variable(old: ParamSpecType) -> ParamSpecType:
@@ -1995,8 +1994,8 @@ def param_spec(self) -> ParamSpecType | None:
19951994
upper_bound=Parameters(
19961995
arg_types=[any_type, any_type],
19971996
arg_kinds=[ARG_STAR, ARG_STAR2],
1998-
arg_names=[None, None]
1999-
)
1997+
arg_names=[None, None],
1998+
),
20001999
)
20012000

20022001
def expand_param_spec(

test-data/unit/check-inference.test

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2956,8 +2956,11 @@ T = TypeVar('T')
29562956

29572957
def f(x: Optional[T] = None) -> Callable[..., T]: ...
29582958

2959-
x = f() # E: Need type annotation for "x"
2959+
# TODO: should this warn about needed an annotation? This behavior still _works_...
2960+
x = f()
2961+
reveal_type(x) # N: Revealed type is "def [T] (*Any, **Any) -> T`1"
29602962
y = x
2963+
reveal_type(y) # N: Revealed type is "def [T] (*Any, **Any) -> T`1"
29612964

29622965
[case testDontNeedAnnotationForCallable]
29632966
from typing import TypeVar, Optional, Callable, NoReturn

test-data/unit/check-parameter-specification.test

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,9 +1552,7 @@ class Example(Generic[P]):
15521552
def test(ex: Example[P]) -> Example[Concatenate[int, P]]:
15531553
...
15541554

1555-
ex: Example[int] = test()(reveal_type(Example())) # N: Revealed type is "__main__.Example[<nothing>]"
1556-
# TODO: fix
1557-
reveal_type(test()(Example[int]())) # N: Revealed type is "__main__.Example[<nothing>]" \
1558-
# E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[<nothing>]"
1559-
ex = test()(Example[int]()) # E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[<nothing>]"
1555+
ex: Example[int] = test()(reveal_type(Example())) # N: Revealed type is "__main__.Example[[]]"
1556+
reveal_type(test()(Example[int]())) # N: Revealed type is "__main__.Example[[builtins.int, builtins.int]]"
1557+
ex = test()(Example[int]()) # E: Argument 1 has incompatible type "Example[[int]]"; expected "Example[[]]"
15601558
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)