Skip to content

Fix crash due to checking type variable values too early #4384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,14 +1853,14 @@ def parse_file(self) -> None:

def semantic_analysis(self) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
patches = [] # type: List[Callable[[], None]]
patches = [] # type: List[Tuple[int, Callable[[], None]]]
with self.wrap_context():
self.manager.semantic_analyzer.visit_file(self.tree, self.xpath, self.options, patches)
self.patches = patches

def semantic_analysis_pass_three(self) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
patches = [] # type: List[Callable[[], None]]
patches = [] # type: List[Tuple[int, Callable[[], None]]]
with self.wrap_context():
self.manager.semantic_analyzer_pass3.visit_file(self.tree, self.xpath,
self.options, patches)
Expand All @@ -1869,7 +1869,8 @@ def semantic_analysis_pass_three(self) -> None:
self.patches = patches + self.patches

def semantic_analysis_apply_patches(self) -> None:
for patch_func in self.patches:
patches_by_priority = sorted(self.patches, key=lambda x: x[0])
for priority, patch_func in patches_by_priority:
patch_func()

def type_check_first_pass(self) -> None:
Expand Down
12 changes: 7 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy import join
from mypy.util import get_prefix
from mypy.semanal_shared import PRIORITY_FALLBACKS


T = TypeVar('T')
Expand Down Expand Up @@ -258,11 +259,12 @@ def __init__(self,
self.recurse_into_functions = True

def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
patches: List[Callable[[], None]]) -> None:
patches: List[Tuple[int, Callable[[], None]]]) -> None:
"""Run semantic analysis phase 2 over a file.

Add callbacks by mutating the patches list argument. They will be called
after all semantic analysis phases but before type checking.
Add (priority, callback) pairs by mutating the 'patches' list argument. They
will be called after all semantic analysis phases but before type checking,
lowest priority values first.
"""
self.recurse_into_functions = True
self.options = options
Expand Down Expand Up @@ -2454,7 +2456,7 @@ def patch() -> None:
# We can't calculate the complete fallback type until after semantic
# analysis, since otherwise MROs might be incomplete. Postpone a callback
# function that patches the fallback.
self.patches.append(patch)
self.patches.append((PRIORITY_FALLBACKS, patch))

def add_field(var: Var, is_initialized_in_class: bool = False,
is_property: bool = False) -> None:
Expand Down Expand Up @@ -2693,7 +2695,7 @@ def patch() -> None:
# We can't calculate the complete fallback type until after semantic
# analysis, since otherwise MROs might be incomplete. Postpone a callback
# function that patches the fallback.
self.patches.append(patch)
self.patches.append((PRIORITY_FALLBACKS, patch))
return info

def check_classvar(self, s: AssignmentStmt) -> None:
Expand Down
88 changes: 77 additions & 11 deletions mypy/semanal_pass3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from collections import OrderedDict
from typing import Dict, List, Callable, Optional, Union, Set, cast
from typing import Dict, List, Callable, Optional, Union, Set, cast, Tuple

from mypy import messages, experiments
from mypy.nodes import (
Expand All @@ -28,6 +28,9 @@
from mypy.traverser import TraverserVisitor
from mypy.typeanal import TypeAnalyserPass3, collect_any_types
from mypy.typevars import has_no_typevars
from mypy.semanal_shared import PRIORITY_FORWARD_REF, PRIORITY_TYPEVAR_VALUES
from mypy.subtypes import is_subtype
from mypy.sametypes import is_same_type
import mypy.semanal


Expand All @@ -48,7 +51,7 @@ def __init__(self, modules: Dict[str, MypyFile], errors: Errors,
self.recurse_into_functions = True

def visit_file(self, file_node: MypyFile, fnam: str, options: Options,
patches: List[Callable[[], None]]) -> None:
patches: List[Tuple[int, Callable[[], None]]]) -> None:
self.recurse_into_functions = True
self.errors.set_file(fnam, file_node.fullname())
self.options = options
Expand Down Expand Up @@ -349,12 +352,7 @@ def analyze(self, type: Optional[Type], node: Union[Node, SymbolTableNode],
analyzer = self.make_type_analyzer(indicator)
type.accept(analyzer)
self.check_for_omitted_generics(type)
if indicator.get('forward') or indicator.get('synthetic'):
def patch() -> None:
self.perform_transform(node,
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
node, warn)))
self.patches.append(patch)
self.generate_type_patches(node, indicator, warn)

def analyze_types(self, types: List[Type], node: Node) -> None:
# Similar to above but for nodes with multiple types.
Expand All @@ -363,12 +361,24 @@ def analyze_types(self, types: List[Type], node: Node) -> None:
analyzer = self.make_type_analyzer(indicator)
type.accept(analyzer)
self.check_for_omitted_generics(type)
self.generate_type_patches(node, indicator, warn=False)

def generate_type_patches(self,
node: Union[Node, SymbolTableNode],
indicator: Dict[str, bool],
warn: bool) -> None:
if indicator.get('forward') or indicator.get('synthetic'):
def patch() -> None:
self.perform_transform(node,
lambda tp: tp.accept(ForwardReferenceResolver(self.fail,
node, warn=False)))
self.patches.append(patch)
node, warn)))
self.patches.append((PRIORITY_FORWARD_REF, patch))
if indicator.get('typevar'):
def patch() -> None:
self.perform_transform(node,
lambda tp: tp.accept(TypeVariableChecker(self.fail)))

self.patches.append((PRIORITY_TYPEVAR_VALUES, patch))

def analyze_info(self, info: TypeInfo) -> None:
# Similar to above but for nodes with synthetic TypeInfos (NamedTuple and NewType).
Expand All @@ -387,7 +397,8 @@ def make_type_analyzer(self, indicator: Dict[str, bool]) -> TypeAnalyserPass3:
self.sem.plugin,
self.options,
self.is_typeshed_file,
indicator)
indicator,
self.patches)

def check_for_omitted_generics(self, typ: Type) -> None:
if not self.options.disallow_any_generics or self.is_typeshed_file:
Expand Down Expand Up @@ -606,3 +617,58 @@ def visit_type_type(self, t: TypeType) -> Type:
if self.check_recursion(t):
return AnyType(TypeOfAny.from_error)
return super().visit_type_type(t)


class TypeVariableChecker(TypeTranslator):
    """Validate type arguments of generic instances against their type variables.

    Note: Run this only at the very end of semantic analysis, once MROs are
    complete and forward references have been resolved.

    For every C[X] that is visited, two checks are made:

    - If the corresponding type variable of C is restricted to a set of
      values, X has to match one of those values.
    - X has to be a subtype of the upper bound of that type variable.

    (Nothing is actually translated here; deriving from the type translator
    merely keeps the implementation short.)
    """

    def __init__(self, fail: Callable[[str, Context], None]) -> None:
        # Callback used to report an error message together with its context.
        self.fail = fail

    def visit_instance(self, t: Instance) -> Type:
        """Check each type argument of 't' and return 't' itself, unchanged."""
        info = t.type
        # Positions are 1-based; they are forwarded as 'arg_number' below.
        for position, (arg, tvar) in enumerate(zip(t.args, info.defn.type_vars), start=1):
            if tvar.values:
                if not isinstance(arg, TypeVarType):
                    candidates = [arg]
                elif arg.values:
                    # A restricted type variable is checked value by value.
                    candidates = arg.values
                else:
                    self.fail('Type variable "{}" not valid as type '
                              'argument value for "{}"'.format(
                                  arg.name, info.name()), t)
                    # The upper bound check is skipped for this argument as well.
                    continue
                self.check_type_var_values(info, candidates, tvar.name,
                                           tvar.values, position, t)
            if not is_subtype(arg, tvar.upper_bound):
                self.fail('Type argument "{}" of "{}" must be '
                          'a subtype of "{}"'.format(
                              arg, info.name(), tvar.upper_bound), t)
        return t

    def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
                              valids: List[Type], arg_number: int, context: Context) -> None:
        """Report every type in 'actuals' that is neither Any nor the same as a type in 'valids'.

        'arg_number' is accepted but not used by this method.
        """
        for actual in actuals:
            if isinstance(actual, AnyType):
                continue
            if any(is_same_type(actual, value) for value in valids):
                continue
            if len(actuals) <= 1 and isinstance(actual, Instance):
                # A single offending instance gets the more specific message.
                class_name = '"{}"'.format(type.name())
                actual_type_name = '"{}"'.format(actual.type.name())
                self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
                    arg_name, class_name, actual_type_name), context)
            else:
                self.fail('Invalid type argument value for "{}"'.format(
                    type.name()), context)
11 changes: 11 additions & 0 deletions mypy/semanal_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Shared definitions used by different parts of semantic analysis."""

# Priorities for ordering of patches within the final "patch" phase of semantic analysis
# (after pass 3):

# Fix forward references (needs to happen first)
PRIORITY_FORWARD_REF = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a minor suggestion: I would use the plural PRIORITY_FORWARD_REFS to be more consistent with the other constants.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just make this an enum? That should be possible now that mypy no longer supports being run on 3.3.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to continue using the old idiom, since I like consistency and I don't want to change all the existing constants into enums :-)

# Fix fallbacks (does joins)
PRIORITY_FALLBACKS = 1
# Check type variable values (does subtype checks)
PRIORITY_TYPEVAR_VALUES = 2
74 changes: 16 additions & 58 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Semantic analysis of types"""

from collections import OrderedDict
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict
from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Dict, Union
from itertools import chain

from contextlib import contextmanager
Expand All @@ -14,19 +14,18 @@
Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType,
CallableType, NoneTyp, DeletedType, TypeList, TypeVarDef, TypeVisitor, SyntheticTypeVisitor,
StarType, PartialType, EllipsisType, UninhabitedType, TypeType, get_typ_args, set_typ_args,
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded
CallableArgument, get_type_vars, TypeQuery, union_items, TypeOfAny, ForwardRef, Overloaded,
TypeTranslator
)

from mypy.nodes import (
TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, TypeInfo, Context, SymbolTableNode, Var, Expression,
IndexExpr, RefExpr, nongen_builtins, check_arg_names, check_arg_kinds, ARG_POS, ARG_NAMED,
ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, FuncDef, CallExpr, NameExpr,
Decorator
Decorator, Node
)
from mypy.tvar_scope import TypeVarScope
from mypy.sametypes import is_same_type
from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError
from mypy.subtypes import is_subtype
from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext
from mypy import nodes, messages

Expand Down Expand Up @@ -656,7 +655,8 @@ def __init__(self,
plugin: Plugin,
options: Options,
is_typeshed_stub: bool,
indicator: Dict[str, bool]) -> None:
indicator: Dict[str, bool],
patches: List[Tuple[int, Callable[[], None]]]) -> None:
self.lookup_func = lookup_func
self.lookup_fqn_func = lookup_fqn_func
self.fail = fail_func
Expand All @@ -665,6 +665,7 @@ def __init__(self,
self.plugin = plugin
self.is_typeshed_stub = is_typeshed_stub
self.indicator = indicator
self.patches = patches

def visit_instance(self, t: Instance) -> None:
info = t.type
Expand Down Expand Up @@ -707,64 +708,21 @@ def visit_instance(self, t: Instance) -> None:
t.args = [AnyType(TypeOfAny.from_error) for _ in info.type_vars]
t.invalid = True
elif info.defn.type_vars:
# Check type argument values.
# TODO: Calling is_subtype and is_same_types in semantic analysis is a bad idea
for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars):
if tvar.values:
if isinstance(arg, TypeVarType):
arg_values = arg.values
if not arg_values:
self.fail('Type variable "{}" not valid as type '
'argument value for "{}"'.format(
arg.name, info.name()), t)
continue
else:
arg_values = [arg]
self.check_type_var_values(info, arg_values, tvar.name, tvar.values, i + 1, t)
# TODO: These hacks will be not necessary when this will be moved to later stage.
arg = self.resolve_type(arg)
bound = self.resolve_type(tvar.upper_bound)
if not is_subtype(arg, bound):
self.fail('Type argument "{}" of "{}" must be '
'a subtype of "{}"'.format(
arg, info.name(), bound), t)
# Check type argument values. This is postponed to the end of semantic analysis
# since we need full MROs and resolved forward references.
for tvar in info.defn.type_vars:
if (tvar.values
or not isinstance(tvar.upper_bound, Instance)
or tvar.upper_bound.type.fullname() != 'builtins.object'):
# Some restrictions on type variable. These can only be checked later
# after we have final MROs and forward references have been resolved.
self.indicator['typevar'] = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated idea: indicator is a good candidate to be a TypedDict (but it is not in typing yet).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, or maybe just a regular object.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, or maybe just a regular object.

:-) Good point!

for arg in t.args:
arg.accept(self)
if info.is_newtype:
for base in info.bases:
base.accept(self)

def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str,
valids: List[Type], arg_number: int, context: Context) -> None:
for actual in actuals:
actual = self.resolve_type(actual)
if (not isinstance(actual, AnyType) and
not any(is_same_type(actual, self.resolve_type(value))
for value in valids)):
if len(actuals) > 1 or not isinstance(actual, Instance):
self.fail('Invalid type argument value for "{}"'.format(
type.name()), context)
else:
class_name = '"{}"'.format(type.name())
actual_type_name = '"{}"'.format(actual.type.name())
self.fail(messages.INCOMPATIBLE_TYPEVAR_VALUE.format(
arg_name, class_name, actual_type_name), context)

def resolve_type(self, tp: Type) -> Type:
# This helper is only needed while is_subtype and is_same_type are
# called in third pass. This can be removed when TODO in visit_instance is fixed.
if isinstance(tp, ForwardRef):
if tp.resolved is None:
return tp.unbound
tp = tp.resolved
if isinstance(tp, Instance) and tp.type.replaced:
replaced = tp.type.replaced
if replaced.tuple_type:
tp = replaced.tuple_type
if replaced.typeddict_type:
tp = replaced.typeddict_type
return tp

def visit_callable_type(self, t: CallableType) -> None:
t.ret_type.accept(self)
for arg_type in t.arg_types:
Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-newtype.test
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,10 @@ d: object
if isinstance(d, T): # E: Cannot use isinstance() with a NewType type
reveal_type(d) # E: Revealed type is '__main__.T'
[builtins fixtures/isinstancelist.pyi]

[case testInvalidNewTypeCrash]
from typing import List, NewType, Union
N = NewType('N', XXX) # E: Argument 2 to NewType(...) must be subclassable (got "Any") \
# E: Name 'XXX' is not defined
x: List[Union[N, int]] # E: Invalid type "__main__.N"
[builtins fixtures/list.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,7 @@ T = TypeVar('T', bound='M')
class G(Generic[T]):
x: T

yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict({'x': builtins.int}, fallback=typing.Mapping[builtins.str, builtins.object])"
yb: G[int] # E: Type argument "builtins.int" of "G" must be a subtype of "TypedDict('__main__.M', {'x': builtins.int})"
yg: G[M]
z: int = G[M]().x['x']

Expand Down
7 changes: 7 additions & 0 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,10 @@ x: Union[ExtremelyLongTypeNameWhichIsGenericSoWeCanUseItMultipleTimes[int],
def takes_int(arg: int) -> None: pass

takes_int(x) # E: Argument 1 to "takes_int" has incompatible type <union: 6 items>; expected "int"

[case testRecursiveForwardReferenceInUnion]
from typing import List, Union
MYTYPE = List[Union[str, "MYTYPE"]]
[builtins fixtures/list.pyi]
[out]
main:2: error: Recursive types not fully supported yet, nested types replaced with "Any"