diff --git a/mypy/checker.py b/mypy/checker.py index 9ea3a9b902d93..a7281cb3b40be 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - operand_types = [coerce_to_literal(type_map[expr]) - for expr in node.operands if expr in type_map] - - is_not = node.operators == ['is not'] - if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): - if_vars = {} # type: TypeMap - else_vars = {} # type: TypeMap - - for i, expr in enumerate(node.operands): - var_type = operand_types[i] - other_type = operand_types[1 - i] - - if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): - # This should only be true at most once: there should be - # exactly two elements in node.operands and if the 'other type' is - # a singleton type, it by definition does not need to be narrowed: - # it already has the most precise type possible so does not need to - # be narrowed/included in the output map. - # - # TODO: Generalize this to handle the case where 'other_type' is - # a union of singleton types. + operand_types = [] + for expr in node.operands: + if expr not in type_map: + return {}, {} + operand_types.append(coerce_to_literal(type_map[expr])) + + type_maps = [] + for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()): + left_type = operand_types[i] + right_type = operand_types[i + 1] + + if_map = {} # type: TypeMap + else_map = {} # type: TypeMap + if operator in {'in', 'not in'}: + right_item_type = builtin_item_type(right_type) + if right_item_type is None or is_optional(right_item_type): + continue + if (isinstance(right_item_type, Instance) + and right_item_type.type.fullname() == 'builtins.object'): + continue + + if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE + and not is_literal_none(left_expr) and + is_overlapping_erased_types(left_type, right_item_type)): + if_map, else_map = {left_expr: remove_optional(left_type)}, {} + else: + continue + elif operator in {'==', '!='}: + if_map, else_map = self.narrow_given_equality( + left_expr, left_type, right_expr, right_type, assume_identity=False) + elif operator in {'is', 'is not'}: + if_map, else_map = self.narrow_given_equality( + left_expr, left_type, right_expr, right_type, assume_identity=True) + else: + continue - if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): - fallback_name = other_type.fallback.type.fullname() - var_type = try_expanding_enum_to_union(var_type, fallback_name) + if operator in {'not in', '!=', 'is not'}: + if_map, else_map = else_map, if_map - target_type = [TypeRange(other_type, is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, var_type, target_type) - break + type_maps.append((if_map, else_map)) - if is_not: - if_vars, else_vars = else_vars, if_vars - return if_vars, else_vars - # Check for `x == y` where x is of type Optional[T] and y is of type T - # or a type that overlaps with T (or vice versa). - elif node.operators == ['==']: - first_type = type_map[node.operands[0]] - second_type = type_map[node.operands[1]] - if is_optional(first_type) != is_optional(second_type): - if is_optional(first_type): - optional_type, comp_type = first_type, second_type - optional_expr = node.operands[0] - else: - optional_type, comp_type = second_type, first_type - optional_expr = node.operands[1] - if is_overlapping_erased_types(optional_type, comp_type): - return {optional_expr: remove_optional(optional_type)}, {} - elif node.operators in [['in'], ['not in']]: - expr = node.operands[0] - left_type = type_map[expr] - right_type = builtin_item_type(type_map[node.operands[1]]) - right_ok = right_type and (not is_optional(right_type) and - (not isinstance(right_type, Instance) or - right_type.type.fullname() != 'builtins.object')) - if (right_type and right_ok and is_optional(left_type) and - literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_erased_types(left_type, right_type)): - if node.operators == ['in']: - return {expr: remove_optional(left_type)}, {} - if node.operators == ['not in']: - return {}, {expr: remove_optional(left_type)} + if len(type_maps) == 0: + return {}, {} + elif len(type_maps) == 1: + return type_maps[0] + else: + # Comparisons like 'a == b == c is d' is the same thing as + # '(a == b) and (b == c) and (c is d)'. So after generating each + # individual comparison's typemaps, we "and" them together here. + # (Also see comments below where we handle the 'and' OpExpr.) + final_if_map, final_else_map = type_maps[0] + for if_map, else_map in type_maps[1:]: + final_if_map = and_conditional_maps(final_if_map, if_map) + final_else_map = or_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression # Not a supported isinstance check return {}, {} + def narrow_given_equality(self, + left_expr: Expression, + left_type: Type, + right_expr: Expression, + right_type: Type, + assume_identity: bool, + ) -> Tuple[TypeMap, TypeMap]: + """Assuming that the given 'left' and 'right' exprs are equal to each other, try + producing TypeMaps refining the types of either the left or right exprs (or neither, + if we can't learn anything from the comparison). + + For more details about what TypeMaps are, see the docstring in find_isinstance_check. + + If 'assume_identity' is true, assume that this comparison was done using an + identity comparison (left_expr is right_expr), not just an equality comparison + (left_expr == right_expr). Identity checks are not overridable, so we can infer + more information in that case. + """ + + # For the sake of simplicity, we currently attempt inferring a more precise type + # for just one of the two variables. + comparisons = [ + (left_expr, left_type, right_type), + (right_expr, right_type, left_type), + ] + + for expr, expr_type, other_type in comparisons: + # The 'expr' isn't an expression that we can refine the type of. Skip + # attempting to refine this expr. + if literal(expr) != LITERAL_TYPE: + continue + + # Case 1: If the 'other_type' is a singleton (only one value has + # the specified type), attempt to narrow 'expr_type' to just that + # singleton type. + if is_singleton_type(other_type): + if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): + if not assume_identity: + # Our checks need to be more conservative if the operand is + # '==' or '!=': all bets are off if either of the two operands + # has a custom `__eq__` or `__ne__` method. + # + # So, we permit this check to succeed only if 'other_type' does + # not define custom equality logic + if not uses_default_equality_checks(expr_type): + continue + if not uses_default_equality_checks(other_type.fallback): + continue + fallback_name = other_type.fallback.type.fullname() + expr_type = try_expanding_enum_to_union(expr_type, fallback_name) + + target_type = [TypeRange(other_type, is_upper_bound=False)] + return conditional_type_map(expr, expr_type, target_type) + + # Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'. + # + # Note: This check is actually strictly speaking unsafe: stripping away the 'None' + # would be unsound in the case where A defines an '__eq__' method that always + # returns 'True', for example. + # + # We implement this check partly for backwards-compatibility reasons and partly + # because those kinds of degenerate '__eq__' implementations are probably rare + # enough that this is fine in practice. + # + # We could also probably generalize this block to strip away *any* singleton type, + # if we were fine with a bit more unsoundness. + if is_optional(expr_type) and not is_optional(other_type): + if is_overlapping_erased_types(expr_type, other_type): + return {expr: remove_optional(expr_type)}, {} + + return {}, {} + # # Helpers # @@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') +def uses_default_equality_checks(typ: Type) -> bool: + """Returns 'true' if we know for certain that the given type is using + the default __eq__ and __ne__ checks defined in 'builtins.object'. + We can use this information to make more aggressive inferences when + analyzing things like equality checks. + + When in doubt, this function will conservatively bias towards + returning False. + """ + if isinstance(typ, UnionType): + return all(map(uses_default_equality_checks, typ.items)) + # TODO: Generalize this so it'll handle other types with fallbacks + if isinstance(typ, LiteralType): + typ = typ.fallback + if isinstance(typ, Instance): + typeinfo = typ.type + eq_sym = typeinfo.get('__eq__') + ne_sym = typeinfo.get('__ne__') + if eq_sym is None or ne_sym is None: + return False + return (eq_sym.fullname == 'builtins.object.__eq__' + and ne_sym.fullname == 'builtins.object.__ne__') + else: + return False + + def is_singleton_type(typ: Type) -> bool: """Returns 'true' if this type is a "singleton type" -- if there exists exactly only one runtime value associated with this type. diff --git a/mypy/nodes.py b/mypy/nodes.py index 44f5fe6100397..0ee5f798641f5 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1718,6 +1718,7 @@ class ComparisonExpr(Expression): def __init__(self, operators: List[str], operands: List[Expression]) -> None: super().__init__() + assert len(operators) + 1 == len(operands) self.operators = operators self.operands = operands self.method_types = [] @@ -1725,6 +1726,13 @@ def __init__(self, operators: List[str], operands: List[Expression]) -> None: def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_comparison_expr(self) + def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]: + """If this comparison expr is "a < b is c == d", yields the sequence + ("<", a, b), ("is", b, c), ("==", c, d) + """ + for i, operator in enumerate(self.operators): + yield operator, self.operands[i], self.operands[i + 1] + class SliceExpr(Expression): """Slice expression (e.g. 'x:y', 'x:', '::2' or ':'). diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 9f015f24986cc..72975fbe0daa2 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int' [out2] main:2: note: Revealed type is 'builtins.str' -[case testEnumReachabilityChecksBasic] +[case testEnumReachabilityChecksBasicIdentity] from enum import Enum from typing_extensions import Literal @@ -659,6 +659,54 @@ else: reveal_type(y) # No output here: this branch is unreachable [builtins fixtures/bool.pyi] +[case testEnumReachabilityChecksBasicEquality] +from enum import Enum +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Literal[Foo.A, Foo.B, Foo.C] +if x == Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif x == Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif x == Foo.C: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +if Foo.A == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C == x: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(x) # No output here: this branch is unreachable + +y: Foo +if y == Foo.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif y == Foo.B: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif y == Foo.C: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable + +if Foo.A == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +elif Foo.B == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +elif Foo.C == y: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' +else: + reveal_type(y) # No output here: this branch is unreachable +[builtins fixtures/bool.pyi] + [case testEnumReachabilityChecksIndirect] from enum import Enum from typing_extensions import Literal, Final @@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str: return 'PROCESSED: ' + response [builtins fixtures/primitives.pyi] + +[case testEnumReachabilityDisabledGivenCustomEquality] +from typing import Union +from enum import Enum + +class Parent(Enum): + def __ne__(self, other: object) -> bool: return True + +class Foo(Enum): + A = 1 + B = 2 + def __eq__(self, other: object) -> bool: return True + +class Bar(Parent): + A = 1 + B = 2 + +class Ok(Enum): + A = 1 + B = 2 + +x: Foo +if x is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + +if x == Foo.A: + reveal_type(x) # N: Revealed type is '__main__.Foo' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + +y: Bar +if y is Bar.A: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]' +else: + reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]' + +if y == Bar.A: + reveal_type(y) # N: Revealed type is '__main__.Bar' +else: + reveal_type(y) # N: Revealed type is '__main__.Bar' + +z1: Union[Bar, Ok] +if z1 is Ok.A: + reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]' +else: + reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]' + +z2: Union[Bar, Ok] +if z2 == Ok.A: + reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' +else: + reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChaining] +from enum import Enum +class Foo(Enum): + A = 1 + B = 2 + +x: Foo +y: Foo +if x is Foo.A is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' + +if x == Foo.A == y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' + reveal_type(y) # N: Revealed type is '__main__.Foo' +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 8f5313870e27c..7e17f8787ea23 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -485,6 +485,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'builtins.str' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'builtins.str' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithUnion] @@ -494,6 +498,10 @@ if x == '': reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is '': + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsNotOptionalWithOverlap] @@ -503,6 +511,10 @@ if x == object(): reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' +if x is object(): + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int]' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] [case testInferEqualsStillOptionalWithNoOverlap] diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index e7f240e919260..4b1abe22347b5 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -1196,6 +1196,7 @@ x = y reveal_type(x) # N: Revealed type is 'Tuple[builtins.int, builtins.int]' [case testTupleOverlapDifferentTuples] +# flags: --strict-optional from typing import Optional, Tuple class A: pass class B: pass