Skip to content

Commit 6998787

Browse files
elazarggvanrossum
authored andcommitted
apply_generic_arguments: replace runtime checks with assertions (#2346)
Part of #2272. The only place where the number of type arguments in type application might not match is when the user does that; this is explicitly checked in visit_type_application. Other places involve type inference, which must be correct "by construction" and if it isn't, that's a bug we want to catch, instead of returning AnyType. It also allows us to remove casts.
1 parent 61d4c94 commit 6998787

File tree

3 files changed

+25
-34
lines changed

3 files changed

+25
-34
lines changed

mypy/applytype.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import mypy.subtypes
44
from mypy.sametypes import is_same_type
55
from mypy.expandtype import expand_type
6-
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, Void
6+
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType
77
from mypy.messages import MessageBuilder
88
from mypy.nodes import Context
99

1010

1111
def apply_generic_arguments(callable: CallableType, types: List[Type],
12-
msg: MessageBuilder, context: Context) -> Type:
12+
msg: MessageBuilder, context: Context) -> CallableType:
1313
"""Apply generic type arguments to a callable type.
1414
1515
For example, applying [int] to 'def [T] (T) -> T' results in
@@ -18,10 +18,7 @@ def apply_generic_arguments(callable: CallableType, types: List[Type],
1818
Note that each type can be None; in this case, it will not be applied.
1919
"""
2020
tvars = callable.variables
21-
if len(tvars) != len(types):
22-
msg.incompatible_type_application(len(tvars), len(types), context)
23-
return AnyType()
24-
21+
assert len(tvars) == len(types)
2522
# Check that inferred type variable values are compatible with allowed
2623
# values and bounds. Also, promote subtype values to allowed values.
2724
types = types[:]

mypy/checkexpr.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,7 @@ def infer_function_type_arguments_using_context(
470470
new_args.append(None)
471471
else:
472472
new_args.append(arg)
473-
return cast(CallableType, self.apply_generic_arguments(callable, new_args,
474-
error_context))
473+
return self.apply_generic_arguments(callable, new_args, error_context)
475474

476475
def infer_function_type_arguments(self, callee_type: CallableType,
477476
args: List[Expression],
@@ -561,9 +560,8 @@ def infer_function_type_arguments_pass2(
561560
for i, arg in enumerate(inferred_args):
562561
if isinstance(arg, (NoneTyp, UninhabitedType)) or has_erased_component(arg):
563562
inferred_args[i] = None
563+
callee_type = self.apply_generic_arguments(callee_type, inferred_args, context)
564564

565-
callee_type = cast(CallableType, self.apply_generic_arguments(
566-
callee_type, inferred_args, context))
567565
arg_types = self.infer_arg_types_in_context2(
568566
callee_type, args, arg_kinds, formal_to_actual)
569567

@@ -609,8 +607,7 @@ def apply_inferred_arguments(self, callee_type: CallableType,
609607
# Apply the inferred types to the function type. In this case the
610608
# return type must be CallableType, since we give the right number of type
611609
# arguments.
612-
return cast(CallableType, self.apply_generic_arguments(callee_type,
613-
inferred_args, context))
610+
return self.apply_generic_arguments(callee_type, inferred_args, context)
614611

615612
def check_argument_count(self, callee: CallableType, actual_types: List[Type],
616613
actual_kinds: List[int], actual_names: List[str],
@@ -724,10 +721,10 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int],
724721

725722
# There may be some remaining tuple varargs items that haven't
726723
# been checked yet. Handle them.
724+
tuplet = arg_types[actual]
727725
if (callee.arg_kinds[i] == nodes.ARG_STAR and
728726
arg_kinds[actual] == nodes.ARG_STAR and
729-
isinstance(arg_types[actual], TupleType)):
730-
tuplet = cast(TupleType, arg_types[actual])
727+
isinstance(tuplet, TupleType)):
731728
while tuple_counter[0] < len(tuplet.items):
732729
actual_type = get_actual_type(arg_type,
733730
arg_kinds[actual],
@@ -880,22 +877,10 @@ def check_arg(caller_type: Type, original_caller_type: Type, caller_kind: int,
880877
return ok
881878

882879
def apply_generic_arguments(self, callable: CallableType, types: List[Type],
883-
context: Context) -> Type:
880+
context: Context) -> CallableType:
884881
"""Simple wrapper around mypy.applytype.apply_generic_arguments."""
885882
return applytype.apply_generic_arguments(callable, types, self.msg, context)
886883

887-
def apply_generic_arguments2(self, overload: Overloaded, types: List[Type],
888-
context: Context) -> Type:
889-
items = [] # type: List[CallableType]
890-
for item in overload.items():
891-
applied = self.apply_generic_arguments(item, types, context)
892-
if isinstance(applied, CallableType):
893-
items.append(applied)
894-
else:
895-
# There was an error.
896-
return AnyType()
897-
return Overloaded(items)
898-
899884
def visit_member_expr(self, e: MemberExpr) -> Type:
900885
"""Visit member expression (of form e.id)."""
901886
self.chk.module_refs.update(extract_refexpr_names(e))
@@ -1375,9 +1360,19 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
13751360
"""Type check a type application (expr[type, ...])."""
13761361
tp = self.accept(tapp.expr)
13771362
if isinstance(tp, CallableType):
1363+
if len(tp.variables) != len(tapp.types):
1364+
self.msg.incompatible_type_application(len(tp.variables),
1365+
len(tapp.types), tapp)
1366+
return AnyType()
13781367
return self.apply_generic_arguments(tp, tapp.types, tapp)
1379-
if isinstance(tp, Overloaded):
1380-
return self.apply_generic_arguments2(tp, tapp.types, tapp)
1368+
elif isinstance(tp, Overloaded):
1369+
for item in tp.items():
1370+
if len(item.variables) != len(tapp.types):
1371+
self.msg.incompatible_type_application(len(item.variables),
1372+
len(tapp.types), tapp)
1373+
return AnyType()
1374+
return Overloaded([self.apply_generic_arguments(item, tapp.types, tapp)
1375+
for item in tp.items()])
13811376
return AnyType()
13821377

13831378
def visit_type_alias_expr(self, alias: TypeAliasExpr) -> Type:
@@ -1569,9 +1564,8 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType:
15691564
# they must be considered as indeterminate. We use ErasedType since it
15701565
# does not affect type inference results (it is for purposes like this
15711566
# only).
1572-
ctx = replace_meta_vars(ctx, ErasedType())
1573-
1574-
callable_ctx = cast(CallableType, ctx)
1567+
callable_ctx = replace_meta_vars(ctx, ErasedType())
1568+
assert isinstance(callable_ctx, CallableType)
15751569

15761570
arg_kinds = [arg.kind for arg in e.arguments]
15771571

mypy/subtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import cast, List, Dict, Callable
1+
from typing import List, Dict, Callable
22

33
from mypy.types import (
44
Type, AnyType, UnboundType, TypeVisitor, ErrorType, Void, NoneTyp,
@@ -333,7 +333,7 @@ def unify_generic_callable(type: CallableType, target: CallableType,
333333
return None
334334
msg = messages.temp_message_builder()
335335
applied = mypy.applytype.apply_generic_arguments(type, inferred_vars, msg, context=target)
336-
if msg.is_errors() or not isinstance(applied, CallableType):
336+
if msg.is_errors():
337337
return None
338338
return applied
339339

0 commit comments

Comments
 (0)