Skip to content

Commit

Permalink
Fix crash on TypeGuard plus "and" (#10496)
Browse files Browse the repository at this point in the history
In python/typeshed#5473, I tried to switch a number of `inspect` functions to use the new `TypeGuard` functionality. Unfortunately, mypy-primer found a number of crashes in third-party libraries in places where a TypeGuard function was ANDed together with some other check. Examples:

- https://github.com/sphinx-doc/sphinx/blob/4.x/sphinx/util/inspect.py#L252
- https://github.com/sphinx-doc/sphinx/blob/4.x/sphinx/ext/coverage.py#L212
- https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/elements/doc_string.py#L105

The problems trace back to the decision in #9865 to make TypeGuardType not inherit from ProperType: in various conditions that are more complicated than a simple `if` check, mypy wants everything to become a ProperType. Therefore, to fix the crashes I had to make TypeGuardType a ProperType and support it in various visitors.
  • Loading branch information
JelleZijlstra authored May 21, 2021
1 parent de6fd6a commit 8e909e4
Show file tree
Hide file tree
Showing 17 changed files with 113 additions and 21 deletions.
5 changes: 4 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
ProperType, get_proper_type, TypeAliasType
ProperType, get_proper_type, TypeAliasType, TypeGuardType
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
Expand Down Expand Up @@ -534,6 +534,9 @@ def visit_union_type(self, template: UnionType) -> List[Constraint]:
def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]:
assert False, "This should be never called, got {}".format(template)

def visit_type_guard_type(self, template: TypeGuardType) -> List[Constraint]:
assert False, "This should be never called, got {}".format(template)

def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]:
res = [] # type: List[Constraint]
for t in types:
Expand Down
5 changes: 4 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
get_proper_type, TypeAliasType
get_proper_type, TypeAliasType, TypeGuardType
)
from mypy.nodes import ARG_STAR, ARG_STAR2

Expand Down Expand Up @@ -90,6 +90,9 @@ def visit_union_type(self, t: UnionType) -> ProperType:
from mypy.typeops import make_simplified_union
return make_simplified_union(erased_items)

def visit_type_guard_type(self, t: TypeGuardType) -> ProperType:
return TypeGuardType(t.type_guard.accept(self))

def visit_type_type(self, t: TypeType) -> ProperType:
return TypeType.make_normalized(t.item.accept(self), line=t.line)

Expand Down
5 changes: 4 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Iterable, List, TypeVar, Mapping, cast

from mypy.types import (
Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType,
Type, Instance, CallableType, TypeGuardType, TypeVisitor, UnboundType, AnyType,
NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType,
ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId,
FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType,
Expand Down Expand Up @@ -126,6 +126,9 @@ def visit_union_type(self, t: UnionType) -> Type:
from mypy.typeops import make_simplified_union # asdf
return make_simplified_union(self.expand_types(t.items), t.line, t.column)

def visit_type_guard_type(self, t: TypeGuardType) -> ProperType:
return TypeGuardType(t.type_guard.accept(self))

def visit_partial_type(self, t: PartialType) -> Type:
return t

Expand Down
5 changes: 4 additions & 1 deletion mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
TypeVarExpr, ClassDef, Block, TypeAlias,
)
from mypy.types import (
CallableType, Instance, Overloaded, TupleType, TypedDictType,
CallableType, Instance, Overloaded, TupleType, TypeGuardType, TypedDictType,
TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType,
TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, TypeVarDef
)
Expand Down Expand Up @@ -254,6 +254,9 @@ def visit_union_type(self, ut: UnionType) -> None:
for it in ut.items:
it.accept(self)

def visit_type_guard_type(self, t: TypeGuardType) -> None:
t.type_guard.accept(self)

def visit_void(self, o: Any) -> None:
pass # Nothing to descend into.

Expand Down
3 changes: 3 additions & 0 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def visit_literal_type(self, t: types.LiteralType) -> Set[str]:
def visit_union_type(self, t: types.UnionType) -> Set[str]:
return self._visit(t.items)

def visit_type_guard_type(self, t: types.TypeGuardType) -> Set[str]:
return self._visit(t.type_guard)

def visit_partial_type(self, t: types.PartialType) -> Set[str]:
return set()

Expand Down
5 changes: 4 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType,
TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type,
ProperType, get_proper_types, TypeAliasType, PlaceholderType
ProperType, get_proper_types, TypeAliasType, PlaceholderType, TypeGuardType
)
from mypy.maptype import map_instance_to_supertype
from mypy.subtypes import (
Expand Down Expand Up @@ -340,6 +340,9 @@ def visit_type_type(self, t: TypeType) -> ProperType:
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
assert False, "This should be never called, got {}".format(t)

def visit_type_guard_type(self, t: TypeGuardType) -> ProperType:
assert False, "This should be never called, got {}".format(t)

def join(self, s: Type, t: Type) -> ProperType:
return join_types(s, t)

Expand Down
5 changes: 4 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType,
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType,
ProperType, get_proper_type, get_proper_types, TypeAliasType
ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardType
)
from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype
from mypy.erasetype import erase_type
Expand Down Expand Up @@ -648,6 +648,9 @@ def visit_type_type(self, t: TypeType) -> ProperType:
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
assert False, "This should be never called, got {}".format(t)

def visit_type_guard_type(self, t: TypeGuardType) -> ProperType:
assert False, "This should be never called, got {}".format(t)

def meet(self, s: Type, t: Type) -> ProperType:
return meet_types(s, t)

Expand Down
9 changes: 8 additions & 1 deletion mypy/sametypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Sequence

from mypy.types import (
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
Type, TypeGuardType, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
ProperType, get_proper_type, TypeAliasType)
Expand All @@ -10,6 +10,7 @@

def is_same_type(left: Type, right: Type) -> bool:
"""Is 'left' the same type as 'right'?"""

left = get_proper_type(left)
right = get_proper_type(right)

Expand Down Expand Up @@ -150,6 +151,12 @@ def visit_union_type(self, left: UnionType) -> bool:
else:
return False

def visit_type_guard_type(self, left: TypeGuardType) -> bool:
if isinstance(self.right, TypeGuardType):
return is_same_type(left.type_guard, self.right.type_guard)
else:
return False

def visit_overloaded(self, left: Overloaded) -> bool:
if isinstance(self.right, Overloaded):
return is_same_types(left.items(), self.right.items())
Expand Down
5 changes: 4 additions & 1 deletion mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
FuncBase, OverloadedFuncDef, FuncItem, MypyFile, UNBOUND_IMPORTED
)
from mypy.types import (
Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType,
Type, TypeGuardType, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType,
ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType,
UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType
)
Expand Down Expand Up @@ -335,6 +335,9 @@ def visit_union_type(self, typ: UnionType) -> SnapshotItem:
normalized = tuple(sorted(items))
return ('UnionType', normalized)

def visit_type_guard_type(self, typ: TypeGuardType) -> SnapshotItem:
return ('TypeGuardType', snapshot_type(typ.type_guard))

def visit_overloaded(self, typ: Overloaded) -> SnapshotItem:
return ('Overloaded', snapshot_types(typ.items()))

Expand Down
5 changes: 4 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType,
TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType,
Overloaded, TypeVarDef, TypeList, CallableArgument, EllipsisType, StarType, LiteralType,
RawExpressionType, PartialType, PlaceholderType, TypeAliasType
RawExpressionType, PartialType, PlaceholderType, TypeAliasType, TypeGuardType
)
from mypy.util import get_prefix, replace_object_state
from mypy.typestate import TypeState
Expand Down Expand Up @@ -389,6 +389,9 @@ def visit_erased_type(self, t: ErasedType) -> None:
def visit_deleted_type(self, typ: DeletedType) -> None:
pass

def visit_type_guard_type(self, typ: TypeGuardType) -> None:
raise RuntimeError

def visit_partial_type(self, typ: PartialType) -> None:
raise RuntimeError

Expand Down
6 changes: 5 additions & 1 deletion mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType,
TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType,
FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType,
TypeAliasType)
TypeAliasType, TypeGuardType
)
from mypy.server.trigger import make_trigger, make_wildcard_trigger
from mypy.util import correct_relative_import
from mypy.scope import Scope
Expand Down Expand Up @@ -970,6 +971,9 @@ def visit_unbound_type(self, typ: UnboundType) -> List[str]:
def visit_uninhabited_type(self, typ: UninhabitedType) -> List[str]:
return []

def visit_type_guard_type(self, typ: TypeGuardType) -> List[str]:
return typ.type_guard.accept(self)

def visit_union_type(self, typ: UnionType) -> List[str]:
triggers = []
for item in typ.items:
Expand Down
13 changes: 12 additions & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing_extensions import Final

from mypy.types import (
Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType,
Type, AnyType, TypeGuardType, UnboundType, TypeVisitor, FormalArgument, NoneType,
Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded,
ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance,
FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType
Expand Down Expand Up @@ -475,6 +475,9 @@ def visit_overloaded(self, left: Overloaded) -> bool:
def visit_union_type(self, left: UnionType) -> bool:
return all(self._is_subtype(item, self.orig_right) for item in left.items)

def visit_type_guard_type(self, left: TypeGuardType) -> bool:
raise RuntimeError("TypeGuard should not appear here")

def visit_partial_type(self, left: PartialType) -> bool:
# This is indeterminate as we don't really know the complete type yet.
raise RuntimeError
Expand Down Expand Up @@ -1374,6 +1377,14 @@ def visit_overloaded(self, left: Overloaded) -> bool:
def visit_union_type(self, left: UnionType) -> bool:
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])

def visit_type_guard_type(self, left: TypeGuardType) -> bool:
if isinstance(self.right, TypeGuardType):
# TypeGuard[bool] is a subtype of TypeGuard[int]
return self._is_proper_subtype(left.type_guard, self.right.type_guard)
else:
# TypeGuards aren't a subtype of anything else for now (but see #10489)
return False

def visit_partial_type(self, left: PartialType) -> bool:
# TODO: What's the right thing to do here?
return False
Expand Down
12 changes: 11 additions & 1 deletion mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
T = TypeVar('T')

from mypy.types import (
Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType,
Type, AnyType, CallableType, Overloaded, TupleType, TypeGuardType, TypedDictType, LiteralType,
RawExpressionType, Instance, NoneType, TypeType,
UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeDef,
UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument,
Expand Down Expand Up @@ -103,6 +103,10 @@ def visit_type_type(self, t: TypeType) -> T:
def visit_type_alias_type(self, t: TypeAliasType) -> T:
pass

@abstractmethod
def visit_type_guard_type(self, t: TypeGuardType) -> T:
pass


@trait
@mypyc_attr(allow_interpreted_subclasses=True)
Expand Down Expand Up @@ -220,6 +224,9 @@ def visit_union_type(self, t: UnionType) -> Type:
def translate_types(self, types: Iterable[Type]) -> List[Type]:
return [t.accept(self) for t in types]

def visit_type_guard_type(self, t: TypeGuardType) -> Type:
return TypeGuardType(t.type_guard.accept(self))

def translate_variables(self,
variables: Sequence[TypeVarLikeDef]) -> Sequence[TypeVarLikeDef]:
return variables
Expand Down Expand Up @@ -319,6 +326,9 @@ def visit_star_type(self, t: StarType) -> T:
def visit_union_type(self, t: UnionType) -> T:
return self.query_types(t.items)

def visit_type_guard_type(self, t: TypeGuardType) -> T:
return t.type_guard.accept(self)

def visit_overloaded(self, t: Overloaded) -> T:
return self.query_types(t.items())

Expand Down
5 changes: 4 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy.types import (
Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType,
CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor,
StarType, PartialType, EllipsisType, UninhabitedType, TypeType,
StarType, PartialType, EllipsisType, UninhabitedType, TypeType, TypeGuardType,
CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType,
PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeDef, ParamSpecDef
)
Expand Down Expand Up @@ -542,6 +542,9 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
)
return ret

def visit_type_guard_type(self, t: TypeGuardType) -> Type:
return t

def anal_type_guard(self, t: Type) -> Optional[Type]:
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
Expand Down
20 changes: 13 additions & 7 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,14 @@ def copy_modified(self, *,
self.line, self.column)


class TypeGuardType(Type):
class ProperType(Type):
"""Not a type alias.
Every type except TypeAliasType must inherit from this type.
"""


class TypeGuardType(ProperType):
"""Only used by find_instance_check() etc."""
def __init__(self, type_guard: Type):
super().__init__(line=type_guard.line, column=type_guard.column)
Expand All @@ -279,12 +286,8 @@ def __init__(self, type_guard: Type):
def __repr__(self) -> str:
return "TypeGuard({})".format(self.type_guard)


class ProperType(Type):
"""Not a type alias.
Every type except TypeAliasType must inherit from this type.
"""
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_type_guard_type(self)


class TypeVarId:
Expand Down Expand Up @@ -2183,6 +2186,9 @@ def visit_union_type(self, t: UnionType) -> str:
s = self.list_str(t.items)
return 'Union[{}]'.format(s)

def visit_type_guard_type(self, t: TypeGuardType) -> str:
return 'TypeGuard[{}]'.format(t.type_guard.accept(self))

def visit_partial_type(self, t: PartialType) -> str:
if t.type is None:
return '<partial None>'
Expand Down
5 changes: 4 additions & 1 deletion mypy/typetraverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType,
TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType,
Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType,
PlaceholderType, PartialType, RawExpressionType, TypeAliasType
PlaceholderType, PartialType, RawExpressionType, TypeAliasType, TypeGuardType
)


Expand Down Expand Up @@ -62,6 +62,9 @@ def visit_typeddict_type(self, t: TypedDictType) -> None:
def visit_union_type(self, t: UnionType) -> None:
self.traverse_types(t.items)

def visit_type_guard_type(self, t: TypeGuardType) -> None:
t.type_guard.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
self.traverse_types(t.items())

Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,24 @@ class C:
class D(C):
def is_float(self, a: object) -> bool: pass # E: Signature of "is_float" incompatible with supertype "C"
[builtins fixtures/tuple.pyi]

[case testTypeGuardInAnd]
from typing import Any
from typing_extensions import TypeGuard
import types
def isclass(a: object) -> bool:
pass
def ismethod(a: object) -> TypeGuard[float]:
pass
def isfunction(a: object) -> TypeGuard[str]:
pass
def isclassmethod(obj: Any) -> bool:
if ismethod(obj) and obj.__self__ is not None and isclass(obj.__self__): # E: "float" has no attribute "__self__"
return True

return False
def coverage(obj: Any) -> bool:
if not (ismethod(obj) or isfunction(obj)):
return True
return False
[builtins fixtures/classmethod.pyi]

0 comments on commit 8e909e4

Please sign in to comment.