Skip to content

Commit 626ff68

Browse files
authored
Improve usage of outer context for inference (#5699)
Fixes #4872 Fixes #3876 Fixes #2678 Fixes #5199 Fixes #5493 (It also fixes a bunch of similar issues previously closed as duplicates, except one, see below). This PR fixes a problems when mypy commits to soon to using outer context for type inference. This is done by: * Postponing inference to inner (argument) context in situations where type inferred from outer (return) context doesn't satisfy bounds or constraints. * Adding a special case for situation where optional return is inferred against optional context. In such situation, unwrapping the optional is a better idea in 99% of cases. (Note: this doesn't affect type safety, only gives empirically more reasonable inferred types.) In general, instead of adding a special case, it would be better to use inner and outer context at the same time, but this a big change (see comment in code), and using the simple special case fixes majority of issues. Among reported issues, only #5311 will stay unfixed.
1 parent baa4725 commit 626ff68

File tree

8 files changed

+393
-47
lines changed

8 files changed

+393
-47
lines changed

mypy/applytype.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99

1010

1111
def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optional[Type]],
12-
msg: MessageBuilder, context: Context) -> CallableType:
12+
msg: MessageBuilder, context: Context,
13+
skip_unsatisfied: bool = False) -> CallableType:
1314
"""Apply generic type arguments to a callable type.
1415
1516
For example, applying [int] to 'def [T] (T) -> T' results in
1617
'def (int) -> int'.
1718
1819
Note that each type can be None; in this case, it will not be applied.
20+
21+
If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
22+
bound or constraints, instead of giving an error.
1923
"""
2024
tvars = callable.variables
2125
assert len(tvars) == len(orig_types)
@@ -25,7 +29,9 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
2529
for i, type in enumerate(types):
2630
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
2731
values = callable.variables[i].values
28-
if values and type:
32+
if type is None:
33+
continue
34+
if values:
2935
if isinstance(type, AnyType):
3036
continue
3137
if isinstance(type, TypeVarType) and type.values:
@@ -34,15 +40,31 @@ def apply_generic_arguments(callable: CallableType, orig_types: Sequence[Optiona
3440
if all(any(is_same_type(v, v1) for v in values)
3541
for v1 in type.values):
3642
continue
43+
matching = []
3744
for value in values:
3845
if mypy.subtypes.is_subtype(type, value):
39-
types[i] = value
40-
break
46+
matching.append(value)
47+
if matching:
48+
best = matching[0]
49+
# If there are more than one matching value, we select the narrowest
50+
for match in matching[1:]:
51+
if mypy.subtypes.is_subtype(match, best):
52+
best = match
53+
types[i] = best
4154
else:
42-
msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context)
43-
upper_bound = callable.variables[i].upper_bound
44-
if type and not mypy.subtypes.is_subtype(type, upper_bound):
45-
msg.incompatible_typevar_value(callable, type, callable.variables[i].name, context)
55+
if skip_unsatisfied:
56+
types[i] = None
57+
else:
58+
msg.incompatible_typevar_value(callable, type, callable.variables[i].name,
59+
context)
60+
else:
61+
upper_bound = callable.variables[i].upper_bound
62+
if not mypy.subtypes.is_subtype(type, upper_bound):
63+
if skip_unsatisfied:
64+
types[i] = None
65+
else:
66+
msg.incompatible_typevar_value(callable, type, callable.variables[i].name,
67+
context)
4668

4769
# Create a map from type variable id to target type.
4870
id_to_type = {} # type: Dict[TypeVarId, Type]

mypy/checker.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
3232
Instance, NoneTyp, strip_type, TypeType, TypeOfAny,
3333
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
34-
true_only, false_only, function_type, is_named_instance, union_items, TypeQuery
34+
true_only, false_only, function_type, is_named_instance, union_items, TypeQuery,
35+
is_optional, remove_optional
3536
)
3637
from mypy.sametypes import is_same_type, is_same_types
3738
from mypy.messages import MessageBuilder, make_inferred_type_note
@@ -3792,17 +3793,6 @@ def is_literal_none(n: Expression) -> bool:
37923793
return isinstance(n, NameExpr) and n.fullname == 'builtins.None'
37933794

37943795

3795-
def is_optional(t: Type) -> bool:
3796-
return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items)
3797-
3798-
3799-
def remove_optional(typ: Type) -> Type:
3800-
if isinstance(typ, UnionType):
3801-
return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)])
3802-
else:
3803-
return typ
3804-
3805-
38063796
def is_literal_not_implemented(n: Expression) -> bool:
38073797
return isinstance(n, NameExpr) and n.fullname == 'builtins.NotImplemented'
38083798

mypy/checkexpr.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from mypy.types import (
1919
Type, AnyType, CallableType, Overloaded, NoneTyp, TypeVarDef,
2020
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
21-
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, true_only,
22-
false_only, is_named_instance, function_type, callable_type, FunctionLike, StarType,
21+
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny,
22+
true_only, false_only, is_named_instance, function_type, callable_type, FunctionLike,
23+
StarType, is_optional, remove_optional, is_invariant_instance
2324
)
2425
from mypy.nodes import (
2526
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
@@ -30,7 +31,7 @@
3031
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
3132
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
3233
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
33-
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, ClassDef, Block, SymbolNode,
34+
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode,
3435
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, LITERAL_TYPE, REVEAL_TYPE
3536
)
3637
from mypy.literals import literal
@@ -819,20 +820,36 @@ def infer_function_type_arguments_using_context(
819820
# valid results.
820821
erased_ctx = replace_meta_vars(ctx, ErasedType())
821822
ret_type = callable.ret_type
822-
if isinstance(ret_type, TypeVarType):
823-
if ret_type.values or (not isinstance(ctx, Instance) or
824-
not ctx.args):
825-
# The return type is a type variable. If it has values, we can't easily restrict
826-
# type inference to conform to the valid values. If it's unrestricted, we could
827-
# infer a too general type for the type variable if we use context, and this could
828-
# result in confusing and spurious type errors elsewhere.
829-
#
830-
# Give up and just use function arguments for type inference. As an exception,
831-
# if the context is a generic instance type, actually use it as context, as
832-
# this *seems* to usually be the reasonable thing to do.
833-
#
834-
# See also github issues #462 and #360.
835-
ret_type = NoneTyp()
823+
if is_optional(ret_type) and is_optional(ctx):
824+
# If both the context and the return type are optional, unwrap the optional,
825+
# since in 99% cases this is what a user expects. In other words, we replace
826+
# Optional[T] <: Optional[int]
827+
# with
828+
# T <: int
829+
# while the former would infer T <: Optional[int].
830+
ret_type = remove_optional(ret_type)
831+
erased_ctx = remove_optional(erased_ctx)
832+
#
833+
# TODO: Instead of this hack and the one below, we need to use outer and
834+
# inner contexts at the same time. This is however not easy because of two
835+
# reasons:
836+
# * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
837+
# on both sides. (This is not too hard.)
838+
# * We need to update all the inference "infrastructure", so that all
839+
# variables in an expression are inferred at the same time.
840+
# (And this is hard, also we need to be careful with lambdas that require
841+
# two passes.)
842+
if isinstance(ret_type, TypeVarType) and not is_invariant_instance(ctx):
843+
# Another special case: the return type is a type variable. If it's unrestricted,
844+
# we could infer a too general type for the type variable if we use context,
845+
# and this could result in confusing and spurious type errors elsewhere.
846+
#
847+
# Give up and just use function arguments for type inference. As an exception,
848+
# if the context is an invariant instance type, actually use it as context, as
849+
# this *seems* to usually be the reasonable thing to do.
850+
#
851+
# See also github issues #462 and #360.
852+
return callable.copy_modified()
836853
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
837854
# Only substitute non-Uninhabited and non-erased types.
838855
new_args = [] # type: List[Optional[Type]]
@@ -841,7 +858,10 @@ def infer_function_type_arguments_using_context(
841858
new_args.append(None)
842859
else:
843860
new_args.append(arg)
844-
return self.apply_generic_arguments(callable, new_args, error_context)
861+
# Don't show errors after we have only used the outer context for inference.
862+
# We will use argument context to infer more variables.
863+
return self.apply_generic_arguments(callable, new_args, error_context,
864+
skip_unsatisfied=True)
845865

846866
def infer_function_type_arguments(self, callee_type: CallableType,
847867
args: List[Expression],
@@ -1609,9 +1629,10 @@ def check_arg(caller_type: Type, original_caller_type: Type, caller_kind: int,
16091629
return False
16101630

16111631
def apply_generic_arguments(self, callable: CallableType, types: Sequence[Optional[Type]],
1612-
context: Context) -> CallableType:
1632+
context: Context, skip_unsatisfied: bool = False) -> CallableType:
16131633
"""Simple wrapper around mypy.applytype.apply_generic_arguments."""
1614-
return applytype.apply_generic_arguments(callable, types, self.msg, context)
1634+
return applytype.apply_generic_arguments(callable, types, self.msg, context,
1635+
skip_unsatisfied=skip_unsatisfied)
16151636

16161637
def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type:
16171638
"""Visit member expression (of form e.id)."""

mypy/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1918,6 +1918,23 @@ def union_items(typ: Type) -> List[Type]:
19181918
return [typ]
19191919

19201920

1921+
def is_invariant_instance(tp: Type) -> bool:
1922+
if not isinstance(tp, Instance) or not tp.args:
1923+
return False
1924+
return any(v.variance == INVARIANT for v in tp.type.defn.type_vars)
1925+
1926+
1927+
def is_optional(t: Type) -> bool:
1928+
return isinstance(t, UnionType) and any(isinstance(e, NoneTyp) for e in t.items)
1929+
1930+
1931+
def remove_optional(typ: Type) -> Type:
1932+
if isinstance(typ, UnionType):
1933+
return UnionType.make_union([t for t in typ.items if not isinstance(t, NoneTyp)])
1934+
else:
1935+
return typ
1936+
1937+
19211938
names = globals().copy() # type: Final
19221939
names.pop('NOT_READY', None)
19231940
deserialize_map = {

test-data/unit/check-generics.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def fun2(v: Vec[T], scale: T) -> Vec[T]:
849849
return v
850850

851851
reveal_type(fun1([(1, 1)])) # E: Revealed type is 'builtins.int*'
852-
fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[int, int]]"
852+
fun1(1) # E: Argument 1 to "fun1" has incompatible type "int"; expected "List[Tuple[bool, bool]]"
853853
fun1([(1, 'x')]) # E: Cannot infer type argument 1 of "fun1"
854854

855855
reveal_type(fun2([(1, 1)], 1)) # E: Revealed type is 'builtins.list[Tuple[builtins.int*, builtins.int*]]'

0 commit comments

Comments
 (0)