Skip to content

Commit

Permalink
Generalize reachability checks to support enums (#7000)
Browse files Browse the repository at this point in the history
Fixes #1803

This diff adds support for performing reachability and narrowing analysis when doing certain enum checks.

For example, given the following enum:

    class Foo(Enum):
        A = 1
        B = 2

...this pull request will make mypy do the following:

    x: Foo
    if x is Foo.A:
        reveal_type(x)  # type: Literal[Foo.A]
    elif x is Foo.B:
        reveal_type(x)  # type: Literal[Foo.B]
    else:
        reveal_type(x)  # No output: branch inferred as unreachable

This diff does not attempt to perform the same kind of narrowing for equality checks: I suspect implementing those will be harder because `==` can be overridden (e.g., you can define a custom `__eq__` method within an Enum subclass).

This pull request also finally adds support for the enum behavior [described in PEP 484][0], and partially addresses #6366.

  [0]: https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
  • Loading branch information
Michael0x2a authored and ilevkivskyi committed Jul 8, 2019
1 parent 028f202 commit e818a96
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 11 deletions.
111 changes: 100 additions & 11 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
import fnmatch
import sys
from contextlib import contextmanager

from typing import (
Expand Down Expand Up @@ -3535,21 +3536,34 @@ def find_isinstance_check(self, node: Expression
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
elif isinstance(node, ComparisonExpr):
# Check for `x is None` and `x is not None`.
operand_types = [coerce_to_literal(type_map[expr])
for expr in node.operands if expr in type_map]

is_not = node.operators == ['is not']
if any(is_literal_none(n) for n in node.operands) and (
is_not or node.operators == ['is']):
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
if_vars = {} # type: TypeMap
else_vars = {} # type: TypeMap
for expr in node.operands:
if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr)
and expr in type_map):

for i, expr in enumerate(node.operands):
var_type = operand_types[i]
other_type = operand_types[1 - i]

if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
# This should only be true at most once: there should be
# two elements in node.operands, and at least one of them
# should represent a None.
vartype = type_map[expr]
none_typ = [TypeRange(NoneType(), is_upper_bound=False)]
if_vars, else_vars = conditional_type_map(expr, vartype, none_typ)
# exactly two elements in node.operands and if the 'other type' is
# a singleton type, it by definition does not need to be narrowed:
# it already has the most precise type possible so does not need to
# be narrowed/included in the output map.
#
# TODO: Generalize this to handle the case where 'other_type' is
# a union of singleton types.

if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
fallback_name = other_type.fallback.type.fullname()
var_type = try_expanding_enum_to_union(var_type, fallback_name)

target_type = [TypeRange(other_type, is_upper_bound=False)]
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
break

if is_not:
Expand Down Expand Up @@ -4489,3 +4503,78 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool:
def is_private(node_name: str) -> bool:
    """Return True if 'node_name' is private to a class definition.

    A name is considered private when it begins with two underscores and
    does not also end with two underscores, so dunder names such as
    '__init__' are not private.
    """
    if node_name.endswith('__'):
        # Covers dunder names as well as the bare '__' name.
        return False
    return node_name.startswith('__')


def is_singleton_type(typ: Type) -> bool:
    """Return True if 'typ' is a "singleton type".

    A singleton type is a type with exactly one runtime value associated
    with it. Put differently: given any two values 'a' and 'b' that both
    have type 't', 'is_singleton_type(t)' returns True if and only if the
    expression 'a is b' is always true.

    Currently, this returns True only for NoneTypes and enum LiteralTypes.

    Other kinds of LiteralTypes do not count as singleton types. For
    example, after 'a = 100000 + 1' and 'b = 100001' it is not guaranteed
    that 'a is b' holds -- some implementations of Python will end up
    constructing two distinct instances of 100001.
    """
    # TODO: Also make this return True if the type is a bool LiteralType.
    # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented?
    if isinstance(typ, NoneType):
        return True
    return isinstance(typ, LiteralType) and typ.is_enum_literal()


def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type:
    """Attempt to recursively expand enum Instances into a Union of LiteralTypes.

    Any Instance of the enum whose full name is 'target_fullname' is replaced
    by a Union of the LiteralTypes of its members. Unions are expanded item by
    item; every other type is returned unchanged.

    For example, if we have:

        class Color(Enum):
            RED = 1
            BLUE = 2
            YELLOW = 3

        class Status(Enum):
            SUCCESS = 1
            FAILURE = 2
            UNKNOWN = 3

    ...and if we call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`,
    this function will return
    Union[Literal[Color.RED], Literal[Color.BLUE], Literal[Color.YELLOW], Status].

    The reserved names '_order_', '__order__' and '_ignore_' are never
    expanded: the enum machinery consumes them while creating the class, so
    they are not members at runtime.
    """
    if isinstance(typ, UnionType):
        new_items = [try_expanding_enum_to_union(item, target_fullname)
                     for item in typ.items]
        return UnionType.make_simplified_union(new_items)
    elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname:
        # Names that are assigned in the class body (and so show up as Vars in
        # the symbol table) but are removed from the final enum class.
        non_member_names = ('_order_', '__order__', '_ignore_')

        new_items = []
        for name, symbol in typ.type.names.items():
            # Only variable assignments can define members; skip methods,
            # nested classes and so on.
            if not isinstance(symbol.node, Var):
                continue
            if name in non_member_names:
                continue
            new_items.append(LiteralType(name, typ))

        # SymbolTables are really just dicts, and dicts are guaranteed to preserve
        # insertion order only starting with Python 3.7. So, we sort these for older
        # versions of Python to help make tests deterministic.
        #
        # We could probably skip the sort for Python 3.6 since people probably run mypy
        # only using CPython, but we might as well for the sake of full correctness.
        if sys.version_info < (3, 7):
            new_items.sort(key=lambda lit: lit.value)
        return UnionType.make_simplified_union(new_items)
    else:
        return typ


def coerce_to_literal(typ: Type) -> Type:
    """Recursively replace Instances by their last known literal value.

    An Instance with a truthy 'last_known_value' is converted into that
    value (the corresponding LiteralType). Unions are converted item by item
    and then re-simplified. Any other type is returned unchanged.
    """
    if not isinstance(typ, UnionType):
        if isinstance(typ, Instance) and typ.last_known_value:
            return typ.last_known_value
        return typ

    coerced_items = [coerce_to_literal(member) for member in typ.items]
    return UnionType.make_simplified_union(coerced_items)
244 changes: 244 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,247 @@ class SomeEnum(Enum):
main:2: note: Revealed type is 'builtins.int'
[out2]
main:2: note: Revealed type is 'builtins.str'

[case testEnumReachabilityChecksBasic]
from enum import Enum
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

x: Literal[Foo.A, Foo.B, Foo.C]
if x is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif x is Foo.B:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif x is Foo.C:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

if Foo.A is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

y: Foo
if y is Foo.A:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif y is Foo.B:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif y is Foo.C:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable

if Foo.A is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksIndirect]
from enum import Enum
from typing_extensions import Literal, Final

class Foo(Enum):
A = 1
B = 2
C = 3

def accepts_foo_a(x: Literal[Foo.A]) -> None: ...

x: Foo
y: Literal[Foo.A]
z: Final = Foo.A

if x is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
if y is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'

if x is z:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)
if z is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)

if y is z:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(y) # No output: this branch is unreachable
reveal_type(z) # No output: this branch is unreachable
if z is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is '__main__.Foo'
accepts_foo_a(z)
else:
reveal_type(y) # No output: this branch is unreachable
reveal_type(z) # No output: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityNoNarrowingForUnionMessiness]
from enum import Enum
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

x: Foo
y: Literal[Foo.A, Foo.B]
z: Literal[Foo.B, Foo.C]

# For the sake of simplicity, no narrowing is done when the narrower type is a Union.
if x is y:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'

if y is z:
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
else:
reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]'
reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithNone]
# flags: --strict-optional
from enum import Enum
from typing import Optional

class Foo(Enum):
A = 1
B = 2
C = 3

x: Optional[Foo]
if x:
reveal_type(x) # N: Revealed type is '__main__.Foo'
else:
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'

if x is not None:
reveal_type(x) # N: Revealed type is '__main__.Foo'
else:
reveal_type(x) # N: Revealed type is 'None'

if x is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithMultipleEnums]
from enum import Enum
from typing import Union
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
class Bar(Enum):
A = 1
B = 2

x1: Union[Foo, Bar]
if x1 is Foo.A:
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'

x2: Union[Foo, Bar]
if x2 is Bar.A:
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
else:
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'

x3: Union[Foo, Bar]
if x3 is Foo.A or x3 is Bar.A:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
else:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'

[builtins fixtures/bool.pyi]

[case testEnumReachabilityPEP484Example1]
# flags: --strict-optional
from typing import Union
from typing_extensions import Final
from enum import Enum

class Empty(Enum):
token = 0
_empty: Final = Empty.token

def func(x: Union[int, None, Empty] = _empty) -> int:
boom = x + 42 # E: Unsupported left operand type for + ("None") \
# E: Unsupported left operand type for + ("Empty") \
# N: Left operand is of type "Union[int, None, Empty]"
if x is _empty:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
return 0
elif x is None:
reveal_type(x) # N: Revealed type is 'None'
return 1
else: # At this point typechecker knows that x can only have type int
reveal_type(x) # N: Revealed type is 'builtins.int'
return x + 2
[builtins fixtures/primitives.pyi]

[case testEnumReachabilityPEP484Example2]
from typing import Union
from enum import Enum

class Reason(Enum):
timeout = 1
error = 2

def process(response: Union[str, Reason] = '') -> str:
if response is Reason.timeout:
reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.timeout]'
return 'TIMEOUT'
elif response is Reason.error:
reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.error]'
return 'ERROR'
else:
# response can be only str, all other possible values exhausted
reveal_type(response) # N: Revealed type is 'builtins.str'
return 'PROCESSED: ' + response

[builtins fixtures/primitives.pyi]

0 comments on commit e818a96

Please sign in to comment.