From e818a96917a5111c4a5b901d49aa6b36de9700d3 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 8 Jul 2019 08:23:54 -0700 Subject: [PATCH] Generalize reachability checks to support enums (#7000) Fixes https://github.com/python/mypy/issues/1803 This diff adds support for performing reachability and narrowing analysis when doing certain enum checks. For example, given the following enum: class Foo(Enum): A = 1 B = 2 ...this pull request will make mypy do the following: x: Foo if x is Foo.A: reveal_type(x) # type: Literal[Foo.A] elif x is Foo.B: reveal_type(x) # type: Literal[Foo.B] else: reveal_type(x) # No output: branch inferred as unreachable This diff does not attempt to perform this same sort of narrowing for equality checks: I suspect implementing those will be harder due to their overridable nature. (E.g. you can define custom `__eq__` methods within Enum subclasses). This pull request also finally adds support for the enum behavior [described in PEP 484][0] and also sort of partially addresses https://github.com/python/mypy/issues/6366 [0]: https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions --- mypy/checker.py | 111 +++++++++++++-- test-data/unit/check-enum.test | 244 +++++++++++++++++++++++++++++++++ 2 files changed, 344 insertions(+), 11 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 45a8b2e24bb9..9ea3a9b902d9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,6 +2,7 @@ import itertools import fnmatch +import sys from contextlib import contextmanager from typing import ( @@ -3535,21 +3536,34 @@ def find_isinstance_check(self, node: Expression vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - # Check for `x is None` and `x is not None`. 
+ operand_types = [coerce_to_literal(type_map[expr]) + for expr in node.operands if expr in type_map] + is_not = node.operators == ['is not'] - if any(is_literal_none(n) for n in node.operands) and ( - is_not or node.operators == ['is']): + if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): if_vars = {} # type: TypeMap else_vars = {} # type: TypeMap - for expr in node.operands: - if (literal(expr) == LITERAL_TYPE and not is_literal_none(expr) - and expr in type_map): + + for i, expr in enumerate(node.operands): + var_type = operand_types[i] + other_type = operand_types[1 - i] + + if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): # This should only be true at most once: there should be - # two elements in node.operands, and at least one of them - # should represent a None. - vartype = type_map[expr] - none_typ = [TypeRange(NoneType(), is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, vartype, none_typ) + # exactly two elements in node.operands and if the 'other type' is + # a singleton type, it by definition does not need to be narrowed: + # it already has the most precise type possible so does not need to + # be narrowed/included in the output map. + # + # TODO: Generalize this to handle the case where 'other_type' is + # a union of singleton types. 
+ + if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): + fallback_name = other_type.fallback.type.fullname() + var_type = try_expanding_enum_to_union(var_type, fallback_name) + + target_type = [TypeRange(other_type, is_upper_bound=False)] + if_vars, else_vars = conditional_type_map(expr, var_type, target_type) break if is_not: @@ -4489,3 +4503,78 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool: def is_private(node_name: str) -> bool: """Check if node is private to class definition.""" return node_name.startswith('__') and not node_name.endswith('__') + + +def is_singleton_type(typ: Type) -> bool: + """Returns True if this type is a "singleton type" -- that is, if there exists + exactly one runtime value associated with this type. + + That is, given two values 'a' and 'b' that have the same type 't', + 'is_singleton_type(t)' returns True if and only if the expression 'a is b' is + always true. + + Currently, this returns True when given NoneTypes and enum LiteralTypes. + + Note that other kinds of LiteralTypes cannot count as singleton types. For + example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed + that 'a is b' will always be true -- some implementations of Python will end up + constructing two distinct instances of 100001. + """ + # TODO: Also make this return True if the type is a bool LiteralType. + # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented? + return isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) + + +def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> Type: + """Attempts to recursively expand any enum Instances with the given target_fullname + into a Union of all of its component LiteralTypes. 
+ + For example, if we have: + + class Color(Enum): + RED = 1 + BLUE = 2 + YELLOW = 3 + + class Status(Enum): + SUCCESS = 1 + FAILURE = 2 + UNKNOWN = 3 + + ...and if we call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`, + this returns Union[Literal[Color.RED], Literal[Color.BLUE], Literal[Color.YELLOW], Status]. + """ + if isinstance(typ, UnionType): + new_items = [try_expanding_enum_to_union(item, target_fullname) + for item in typ.items] + return UnionType.make_simplified_union(new_items) + elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: + new_items = [] + for name, symbol in typ.type.names.items(): + if not isinstance(symbol.node, Var): + continue + new_items.append(LiteralType(name, typ)) + # SymbolTables are really just dicts, and dicts are guaranteed to preserve + # insertion order only starting with Python 3.7. So, we sort these for older + # versions of Python to help make tests deterministic. + # + # We could probably skip the sort for Python 3.6 since people probably run mypy + # only using CPython, but we might as well for the sake of full correctness. + if sys.version_info < (3, 7): + new_items.sort(key=lambda lit: lit.value) + return UnionType.make_simplified_union(new_items) + else: + return typ + + +def coerce_to_literal(typ: Type) -> Type: + """Recursively converts any Instances that have a last_known_value into the + corresponding LiteralType. 
+ """ + if isinstance(typ, UnionType): + new_items = [coerce_to_literal(item) for item in typ.items] + return UnionType.make_simplified_union(new_items) + elif isinstance(typ, Instance) and typ.last_known_value: + return typ.last_known_value + else: + return typ diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 431f0c9b241f..9f015f24986c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -610,3 +610,247 @@ class SomeEnum(Enum): main:2: note: Revealed type is 'builtins.int' [out2] main:2: note: Revealed type is 'builtins.str' + +[case testEnumReachabilityChecksBasic] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Literal[Foo.A, Foo.B, Foo.C] +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif x is Foo.C: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +if Foo.A is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +y: Foo +if y is Foo.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif y is Foo.B: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif y is Foo.C: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable + +if Foo.A is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + 
reveal_type(y) # No output here: this branch is unreachable +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityChecksIndirect] +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +def accepts_foo_a(x: Literal[Foo.A]) -> None: ... + +x: Foo +y: Literal[Foo.A] +z: Final = Foo.A + +if x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +if y is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + +if x is z: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +if z is x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) + +if y is z: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(y) # No output: this branch is unreachable + reveal_type(z) # No output: this branch is unreachable +if z is y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(z) # N: 
Revealed type is '__main__.Foo' + accepts_foo_a(z) +else: + reveal_type(y) # No output: this branch is unreachable + reveal_type(z) # No output: this branch is unreachable +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityNoNarrowingForUnionMessiness] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Foo +y: Literal[Foo.A, Foo.B] +z: Literal[Foo.B, Foo.C] + +# For the sake of simplicity, no narrowing is done when the narrower type is a Union. +if x is y: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + +if y is z: + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +else: + reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' + reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityWithNone] +# flags: --strict-optional +from enum import Enum +from typing import Optional + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Optional[Foo] +if x: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' + +if x is not None: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is 'None' + +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]' +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityWithMultipleEnums] 
+from enum import Enum +from typing import Union +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 +class Bar(Enum): + A = 1 + B = 2 + +x1: Union[Foo, Bar] +if x1 is Foo.A: + reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]' + +x2: Union[Foo, Bar] +if x2 is Bar.A: + reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]' +else: + reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]' + +x3: Union[Foo, Bar] +if x3 is Foo.A or x3 is Bar.A: + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]' +else: + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]' + +[builtins fixtures/bool.pyi] + +[case testEnumReachabilityPEP484Example1] +# flags: --strict-optional +from typing import Union +from typing_extensions import Final +from enum import Enum + +class Empty(Enum): + token = 0 +_empty: Final = Empty.token + +def func(x: Union[int, None, Empty] = _empty) -> int: + boom = x + 42 # E: Unsupported left operand type for + ("None") \ + # E: Unsupported left operand type for + ("Empty") \ + # N: Left operand is of type "Union[int, None, Empty]" + if x is _empty: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + return 0 + elif x is None: + reveal_type(x) # N: Revealed type is 'None' + return 1 + else: # At this point typechecker knows that x can only have type int + reveal_type(x) # N: Revealed type is 'builtins.int' + return x + 2 +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityPEP484Example2] +from typing import Union +from enum import Enum + +class Reason(Enum): + timeout = 1 + error = 2 + +def process(response: Union[str, Reason] = '') -> str: + if response is Reason.timeout: + reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.timeout]' + return 'TIMEOUT' + elif 
response is Reason.error: + reveal_type(response) # N: Revealed type is 'Literal[__main__.Reason.error]' + return 'ERROR' + else: + # response can be only str, all other possible values exhausted + reveal_type(response) # N: Revealed type is 'builtins.str' + return 'PROCESSED: ' + response + +[builtins fixtures/primitives.pyi]