Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize reachability checks to support enums #7000

Merged
merged 8 commits into from
Jul 8, 2019
106 changes: 95 additions & 11 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
import fnmatch
import sys
from contextlib import contextmanager

from typing import (
Expand Down Expand Up @@ -3487,21 +3488,34 @@ def find_isinstance_check(self, node: Expression
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
elif isinstance(node, ComparisonExpr):
# Check for `x is None` and `x is not None`.
operand_types = [coerce_to_literal(type_map[expr])
for expr in node.operands if expr in type_map]

is_not = node.operators == ['is not']
if any(is_literal_none(n) for n in node.operands) and (
is_not or node.operators == ['is']):
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would be it be to do exactly the same for ==? (Mostly so that example in #4223 will not give the false positive.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's slightly trickier -- the semantics of a is SomeEnum.Foo and a == SomeEnum.Foo are different, unfortunately.

If a is something like an int or some other unrelated type, we know the first expression will always be False. But for the second, we have no idea since a could have defined a custom __eq__ function. SomeEnum itself could also have defined/inherited a custom __eq__ method, which would further complicate things.

I'll submit a separate PR for this: it ended up being easier to make this change if I also added support for chained operator comparisons at the same time (see below).

if_vars = {} # type: TypeMap
else_vars = {} # type: TypeMap
for expr in node.operands:
if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr)
and expr in type_map):

for i, expr in enumerate(node.operands):
var_type = operand_types[i]
other_type = operand_types[1 - i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a mystery to me. What if one has if a is b is c or even more operands?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've actually never properly handled this case, I think. The old if-check on line 3490-ish lets this code run only if there's exactly only a single operator; the new if-check I'm replacing that with continues to do the same thing. So, as a consequence, we can safely assume there'll be exactly two operands at this point.

I have a fix for this, but I decided it might be better to submit it as a separate PR. Once I combined this with the equality changes mentioned above, the changes ended up being much more intrusive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old if-check on line 3490-ish lets this code run only if there's exactly only a single operator

Maybe I am missing something, but the code there looks like it is about completely different case, it is about isinstance() and issubclass() having other number of arguments than two.

(Also it is in a different if branch, so it will not affect this branch).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, maybe the line numbers shifted after I merged. It's now line 3541-ish.

The old check used to do this:

is_not = node.operators == ['is not']
if any(is_literal_none(n) for n in node.operands) and (
                is_not or node.operators == ['is']):

And the new checks do this:

is_not = node.operators == ['is not']
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):

We also make the same assumption when handling the == and in operators as well -- those are:

elif node.operators == ['==']:

and:

elif node.operators in [['in'], ['not in']]:


if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
# This should only be true at most once: there should be
# two elements in node.operands, and at least one of them
# should represent a None.
vartype = type_map[expr]
none_typ = [TypeRange(NoneType(), is_upper_bound=False)]
if_vars, else_vars = conditional_type_map(expr, vartype, none_typ)
# exactly two elements in node.operands and if the 'other type' is
# a singleton type, it by definition does not need to be narrowed:
# it already has the most precise type possible so does not need to
# be narrowed/included in the output map.
#
# TODO: Generalize this to handle the case where 'other_type' is
# a union of singleton types.

if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
fallback_name = other_type.fallback.type.fullname()
var_type = try_expanding_enum_to_union(var_type, fallback_name)

target_type = [TypeRange(other_type, is_upper_bound=False)]
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
break

if is_not:
Expand Down Expand Up @@ -4438,3 +4452,73 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool:
def is_private(node_name: str) -> bool:
"""Check if node is private to class definition."""
return node_name.startswith('__') and not node_name.endswith('__')


def is_singleton_type(typ: Type) -> bool:
"""Returns 'true' if this type is a "singleton type" -- if there exists
exactly only one runtime value associated with this type.

That is, given two values 'a' and 'b' that have the same type 't',
'is_singleton_type(t)' returns True if and only if the expression 'a is b' is
always true.

Currently, this returns True when given NoneTypes and enum LiteralTypes.

Note that other kinds of LiteralTypes cannot count as singleton types. For
example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed
that 'a is b' will always be true -- some implementations of Python will end up
constructing two distinct instances of 100001.
"""
# TODO: Also make this return True if the type is a bool LiteralType.
# Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented?
return isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal())


def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type:
"""Attempts to recursively any enum Instances with the given target_fullname
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
into a Union of all of its component LiteralTypes.

For example, if we have:

class Color(Enum):
RED = 1
BLUE = 2
YELLOW = 3

...and if we call `try_expanding_enum_to_union(color_instance, 'module.Color')`,
this function will return Literal[Color.RED, Color.BLUE, Color.YELLOW].
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
"""
if isinstance(typ, UnionType):
new_items = [try_expanding_enum_to_union(item, target_fullname)
for item in typ.items]
return UnionType.make_simplified_union(new_items)
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname:
new_items = []
for name, symbol in typ.type.names.items():
if not isinstance(symbol.node, Var):
continue
new_items.append(LiteralType(name, typ))
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
# insertion order only starting with Python 3.7. So, we sort these for older
# versions of Python to help make tests deterministic.
#
# We could probably skip the sort for Python 3.6 since people probably run mypy
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return UnionType.make_simplified_union(new_items)
else:
return typ


def coerce_to_literal(typ: Type) -> Type:
"""Recursively converts any Instances that have a last_known_value into the
corresponding LiteralType.
"""
if isinstance(typ, UnionType):
new_items = [coerce_to_literal(item) for item in typ.items]
return UnionType.make_simplified_union(new_items)
elif isinstance(typ, Instance) and typ.last_known_value:
return typ.last_known_value
else:
return typ
244 changes: 244 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,247 @@ class SomeEnum(Enum):
main:2: error: Revealed type is 'builtins.int'
[out2]
main:2: error: Revealed type is 'builtins.str'

[case testEnumReachabilityChecksBasic]
from enum import Enum
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

x: Literal[Foo.A, Foo.B, Foo.C]
if x is Foo.A:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
elif x is Foo.B:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.B]'
elif x is Foo.C:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

if Foo.A is x:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B is x:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C is x:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

y: Foo
if y is Foo.A:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
elif y is Foo.B:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.B]'
elif y is Foo.C:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable

if Foo.A is y:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B is y:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C is y:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksIndirect]
from enum import Enum
from typing_extensions import Literal, Final

class Foo(Enum):
A = 1
B = 2
C = 3

def accepts_foo_a(x: Literal[Foo.A]) -> None: ...

x: Foo
y: Literal[Foo.A]
z: Final = Foo.A

if x is y:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
if y is x:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'

if x is z:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)
if z is x:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)

if y is z:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(y) # No output: this branch is unreachable
reveal_type(z) # No output: this branch is unreachable
if z is y:
reveal_type(y) # E: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # E: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(y) # No output: this branch is unreachable
reveal_type(z) # No output: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityNoNarrowingForUnionMessiness]
from enum import Enum
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

x: Foo
y: Literal[Foo.A, Foo.B]
z: Literal[Foo.B, Foo.C]

# For the sake of simplicity, no narrowing is done when the narrower type is a Union.
if x is y:
reveal_type(x) # E: Revealed type is '__main__.Foo'
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
else:
reveal_type(x) # E: Revealed type is '__main__.Foo'
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'

if y is z:
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
else:
reveal_type(y) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithNone]
# flags: --strict-optional
from enum import Enum
from typing import Optional

class Foo(Enum):
A = 1
B = 2
C = 3

x: Optional[Foo]
if x:
reveal_type(x) # E: Revealed type is '__main__.Foo'
else:
reveal_type(x) # E: Revealed type is 'Union[__main__.Foo, None]'

if x is not None:
reveal_type(x) # E: Revealed type is '__main__.Foo'
else:
reveal_type(x) # E: Revealed type is 'None'

if x is Foo.A:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithMultipleEnums]
from enum import Enum
from typing import Union
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
class Bar(Enum):
A = 1
B = 2

x1: Union[Foo, Bar]
if x1 is Foo.A:
reveal_type(x1) # E: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x1) # E: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'

x2: Union[Foo, Bar]
if x2 is Bar.A:
reveal_type(x2) # E: Revealed type is 'Literal[__main__.Bar.A]'
else:
reveal_type(x2) # E: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'

x3: Union[Foo, Bar]
if x3 is Foo.A or x3 is Bar.A:
reveal_type(x3) # E: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
else:
reveal_type(x3) # E: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'

[builtins fixtures/bool.pyi]

[case testEnumReachabilityPEP484Example1]
# flags: --strict-optional
from typing import Union
from typing_extensions import Final
from enum import Enum

class Empty(Enum):
token = 0
_empty: Final = Empty.token

def func(x: Union[int, None, Empty] = _empty) -> int:
boom = x + 42 # E: Unsupported left operand type for + ("None") \
# E: Unsupported left operand type for + ("Empty") \
# N: Left operand is of type "Union[int, None, Empty]"
if x is _empty:
reveal_type(x) # E: Revealed type is 'Literal[__main__.Empty.token]'
return 0
elif x is None:
reveal_type(x) # E: Revealed type is 'None'
return 1
else: # At this point typechecker knows that x can only have type int
reveal_type(x) # E: Revealed type is 'builtins.int'
return x + 2
[builtins fixtures/primitives.pyi]

[case testEnumReachabilityPEP484Example2]
from typing import Union
from enum import Enum

class Reason(Enum):
timeout = 1
error = 2

def process(response: Union[str, Reason] = '') -> str:
if response is Reason.timeout:
reveal_type(response) # E: Revealed type is 'Literal[__main__.Reason.timeout]'
return 'TIMEOUT'
elif response is Reason.error:
reveal_type(response) # E: Revealed type is 'Literal[__main__.Reason.error]'
return 'ERROR'
else:
# response can be only str, all other possible values exhausted
reveal_type(response) # E: Revealed type is 'builtins.str'
return 'PROCESSED: ' + response

[builtins fixtures/primitives.pyi]