Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make reachability code understand chained comparisons #7169

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 150 additions & 58 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3622,67 +3622,61 @@ def find_isinstance_check(self, node: Expression
vartype = type_map[expr]
return self.conditional_callable_type_map(expr, vartype)
elif isinstance(node, ComparisonExpr):
operand_types = [coerce_to_literal(type_map[expr])
for expr in node.operands if expr in type_map]

is_not = node.operators == ['is not']
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands):
if_vars = {} # type: TypeMap
else_vars = {} # type: TypeMap

for i, expr in enumerate(node.operands):
var_type = operand_types[i]
other_type = operand_types[1 - i]

if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type):
# This should only be true at most once: there should be
# exactly two elements in node.operands and if the 'other type' is
# a singleton type, it by definition does not need to be narrowed:
# it already has the most precise type possible so does not need to
# be narrowed/included in the output map.
#
# TODO: Generalize this to handle the case where 'other_type' is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this TODO moved somewhere or is it now fixed? Do we have a regression test in the latter case?

# a union of singleton types.
operand_types = []
for expr in node.operands:
if expr not in type_map:
return {}, {}
operand_types.append(coerce_to_literal(type_map[expr]))

type_maps = []
for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()):
left_type = operand_types[i]
right_type = operand_types[i + 1]

if_map = {} # type: TypeMap
else_map = {} # type: TypeMap
if operator in {'in', 'not in'}:
right_item_type = builtin_item_type(right_type)
if right_item_type is None or is_optional(right_item_type):
continue
if (isinstance(right_item_type, Instance)
and right_item_type.type.fullname() == 'builtins.object'):
continue

if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE
and not is_literal_none(left_expr) and
is_overlapping_erased_types(left_type, right_item_type)):
if_map, else_map = {left_expr: remove_optional(left_type)}, {}
else:
continue
elif operator in {'==', '!='}:
if_map, else_map = self.narrow_given_equality(
left_expr, left_type, right_expr, right_type, assume_identity=False)
elif operator in {'is', 'is not'}:
if_map, else_map = self.narrow_given_equality(
left_expr, left_type, right_expr, right_type, assume_identity=True)
else:
continue

if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
fallback_name = other_type.fallback.type.fullname()
var_type = try_expanding_enum_to_union(var_type, fallback_name)
if operator in {'not in', '!=', 'is not'}:
if_map, else_map = else_map, if_map

target_type = [TypeRange(other_type, is_upper_bound=False)]
if_vars, else_vars = conditional_type_map(expr, var_type, target_type)
break
type_maps.append((if_map, else_map))

if is_not:
if_vars, else_vars = else_vars, if_vars
return if_vars, else_vars
# Check for `x == y` where x is of type Optional[T] and y is of type T
# or a type that overlaps with T (or vice versa).
elif node.operators == ['==']:
first_type = type_map[node.operands[0]]
second_type = type_map[node.operands[1]]
if is_optional(first_type) != is_optional(second_type):
if is_optional(first_type):
optional_type, comp_type = first_type, second_type
optional_expr = node.operands[0]
else:
optional_type, comp_type = second_type, first_type
optional_expr = node.operands[1]
if is_overlapping_erased_types(optional_type, comp_type):
return {optional_expr: remove_optional(optional_type)}, {}
elif node.operators in [['in'], ['not in']]:
expr = node.operands[0]
left_type = type_map[expr]
right_type = get_proper_type(builtin_item_type(type_map[node.operands[1]]))
right_ok = right_type and (not is_optional(right_type) and
(not isinstance(right_type, Instance) or
right_type.type.fullname() != 'builtins.object'))
if (right_type and right_ok and is_optional(left_type) and
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and
is_overlapping_erased_types(left_type, right_type)):
if node.operators == ['in']:
return {expr: remove_optional(left_type)}, {}
if node.operators == ['not in']:
return {}, {expr: remove_optional(left_type)}
if len(type_maps) == 0:
return {}, {}
elif len(type_maps) == 1:
return type_maps[0]
else:
# A comparison like 'a == b == c is d' is the same thing as
# '(a == b) and (b == c) and (c is d)'. So after generating each
# individual comparison's typemaps, we "and" them together here.
# (Also see comments below where we handle the 'and' OpExpr.)
final_if_map, final_else_map = type_maps[0]
for if_map, else_map in type_maps[1:]:
final_if_map = and_conditional_maps(final_if_map, if_map)
final_else_map = or_conditional_maps(final_else_map, else_map)
return final_if_map, final_else_map
elif isinstance(node, RefExpr):
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
# respectively
Expand Down Expand Up @@ -3719,6 +3713,78 @@ def find_isinstance_check(self, node: Expression
# Not a supported isinstance check
return {}, {}

def narrow_given_equality(self,
left_expr: Expression,
left_type: Type,
right_expr: Expression,
right_type: Type,
assume_identity: bool,
) -> Tuple[TypeMap, TypeMap]:
"""Assuming that the given 'left' and 'right' exprs are equal to each other, try
producing TypeMaps refining the types of either the left or right exprs (or neither,
if we can't learn anything from the comparison).

For more details about what TypeMaps are, see the docstring in find_isinstance_check.

If 'assume_identity' is true, assume that this comparison was done using an
identity comparison (left_expr is right_expr), not just an equality comparison
(left_expr == right_expr). Identity checks are not overridable, so we can infer
more information in that case.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also update docstring style here.


# For the sake of simplicity, we currently attempt inferring a more precise type
# for just one of the two variables.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How important is this? If we actually do this, I think it is better to use the type of the right one to restrict the type of the left one.

comparisons = [
(left_expr, left_type, right_type),
(right_expr, right_type, left_type),
]

for expr, expr_type, other_type in comparisons:
# The 'expr' isn't an expression that we can refine the type of. Skip
# attempting to refine this expr.
if literal(expr) != LITERAL_TYPE:
continue

# Case 1: If the 'other_type' is a singleton (only one value has
# the specified type), attempt to narrow 'expr_type' to just that
# singleton type.
if is_singleton_type(other_type):
if isinstance(other_type, LiteralType) and other_type.is_enum_literal():
if not assume_identity:
# Our checks need to be more conservative if the operator is
# '==' or '!=': all bets are off if either of the two operands
# has a custom `__eq__` or `__ne__` method.
#
# So, we permit this check to succeed only if neither 'expr_type'
# nor 'other_type' defines custom equality logic.
if not uses_default_equality_checks(expr_type):
continue
if not uses_default_equality_checks(other_type.fallback):
continue
fallback_name = other_type.fallback.type.fullname()
expr_type = try_expanding_enum_to_union(expr_type, fallback_name)

target_type = [TypeRange(other_type, is_upper_bound=False)]
return conditional_type_map(expr, expr_type, target_type)

# Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'.
#
# Note: This check is actually strictly speaking unsafe: stripping away the 'None'
# would be unsound in the case where A defines an '__eq__' method that always
# returns 'True', for example.
#
# We implement this check partly for backwards-compatibility reasons and partly
# because those kinds of degenerate '__eq__' implementations are probably rare
# enough that this is fine in practice.
#
# We could also probably generalize this block to strip away *any* singleton type,
# if we were fine with a bit more unsoundness.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prepend the last sentence with TODO:?

if is_optional(expr_type) and not is_optional(other_type):
if is_overlapping_erased_types(expr_type, other_type):
return {expr: remove_optional(expr_type)}, {}

return {}, {}

#
# Helpers
#
Expand Down Expand Up @@ -4615,6 +4681,32 @@ def is_private(node_name: str) -> bool:
return node_name.startswith('__') and not node_name.endswith('__')


def uses_default_equality_checks(typ: Type) -> bool:
"""Returns 'true' if we know for certain that the given type is using
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Returns 'true' if we know for certain that the given type is using
"""Returns 'True' if we know for certain that the given type is using

the default __eq__ and __ne__ checks defined in 'builtins.object'.
We can use this information to make more aggressive inferences when
analyzing things like equality checks.

When in doubt, this function will conservatively bias towards
returning False.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use this style*

"""Single line summary.

Longer description...
"""
* This review comment was sponsored by Jukka

if isinstance(typ, UnionType):
return all(map(uses_default_equality_checks, typ.items))
# TODO: Generalize this so it'll handle other types with fallbacks
if isinstance(typ, LiteralType):
typ = typ.fallback
if isinstance(typ, Instance):
typeinfo = typ.type
eq_sym = typeinfo.get('__eq__')
ne_sym = typeinfo.get('__ne__')
if eq_sym is None or ne_sym is None:
return False
return (eq_sym.fullname == 'builtins.object.__eq__'
and ne_sym.fullname == 'builtins.object.__ne__')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks too strong to me. Wouldn't fullname.startswith('builtins.') be enough (like we do for --strict-equality)?

else:
return False


def is_singleton_type(typ: Type) -> bool:
"""Returns 'true' if this type is a "singleton type" -- if there exists
exactly one runtime value associated with this type.
Expand Down
8 changes: 8 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,13 +1734,21 @@ class ComparisonExpr(Expression):

def __init__(self, operators: List[str], operands: List[Expression]) -> None:
# Build a comparison node from parallel lists of operator strings and operand expressions.
super().__init__()
# A chain of N operators always joins N + 1 operands (e.g. "a < b < c" has 2 and 3).
assert len(operators) + 1 == len(operands)
self.operators = operators
self.operands = operands
# Starts out empty; nothing in this constructor populates it.
self.method_types = []

def accept(self, visitor: ExpressionVisitor[T]) -> T:
# Hand this node to the visitor's comparison-expression hook and return whatever it returns.
return visitor.visit_comparison_expr(self)

def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]:
"""Yield an (operator, left operand, right operand) triple per comparison.

For example, if this comparison expr is "a < b is c == d", this yields
the sequence ("<", a, b), ("is", b, c), ("==", c, d).
"""
# operands[i] and operands[i + 1] are the two sides of operators[i].
for i, operator in enumerate(self.operators):
yield operator, self.operands[i], self.operands[i + 1]


class SliceExpr(Expression):
"""Slice expression (e.g. 'x:y', 'x:', '::2' or ':').
Expand Down
128 changes: 127 additions & 1 deletion test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int'
[out2]
main:2: note: Revealed type is 'builtins.str'

[case testEnumReachabilityChecksBasic]
[case testEnumReachabilityChecksBasicIdentity]
from enum import Enum
from typing_extensions import Literal

Expand Down Expand Up @@ -659,6 +659,54 @@ else:
reveal_type(y) # No output here: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksBasicEquality]
from enum import Enum
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

x: Literal[Foo.A, Foo.B, Foo.C]
if x == Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif x == Foo.B:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif x == Foo.C:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

if Foo.A == x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B == x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C == x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable

y: Foo
if y == Foo.A:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif y == Foo.B:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif y == Foo.C:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable

if Foo.A == y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
elif Foo.B == y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
elif Foo.C == y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksIndirect]
from enum import Enum
from typing_extensions import Literal, Final
Expand Down Expand Up @@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str:
return 'PROCESSED: ' + response

[builtins fixtures/primitives.pyi]

[case testEnumReachabilityDisabledGivenCustomEquality]
from typing import Union
from enum import Enum

class Parent(Enum):
def __ne__(self, other: object) -> bool: return True

class Foo(Enum):
A = 1
B = 2
def __eq__(self, other: object) -> bool: return True

class Bar(Parent):
A = 1
B = 2

class Ok(Enum):
A = 1
B = 2

x: Foo
if x is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'

if x == Foo.A:
reveal_type(x) # N: Revealed type is '__main__.Foo'
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'

y: Bar
if y is Bar.A:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]'
else:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]'

if y == Bar.A:
reveal_type(y) # N: Revealed type is '__main__.Bar'
else:
reveal_type(y) # N: Revealed type is '__main__.Bar'

z1: Union[Bar, Ok]
if z1 is Ok.A:
reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]'
else:
reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]'

z2: Union[Bar, Ok]
if z2 == Ok.A:
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
else:
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]'
[builtins fixtures/primitives.pyi]

[case testEnumReachabilityWithChaining]
from enum import Enum
class Foo(Enum):
A = 1
B = 2

x: Foo
y: Foo
if x is Foo.A is y:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'

if x == Foo.A == y:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does x == y == Foo.A work? (It feels like we ought to support this or I'm not sure how much this helps?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is an important case, and we should have a test for this.

reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is '__main__.Foo'
reveal_type(y) # N: Revealed type is '__main__.Foo'
[builtins fixtures/primitives.pyi]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add more tests for chaining. The PR title mentions chaining, but there is literally only one test about this.
In particular I would add tests for:

  • Mixed kinds of comparisons in the chain like x == 2 > y is 3
  • Chained comparisons combined via and/or with other chains and with other kinds of checks, like callable() and isinstance()

Loading