-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make reachability code understand chained comparisons #7169
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3622,67 +3622,61 @@ def find_isinstance_check(self, node: Expression | |||||
vartype = type_map[expr] | ||||||
return self.conditional_callable_type_map(expr, vartype) | ||||||
elif isinstance(node, ComparisonExpr): | ||||||
operand_types = [coerce_to_literal(type_map[expr]) | ||||||
for expr in node.operands if expr in type_map] | ||||||
|
||||||
is_not = node.operators == ['is not'] | ||||||
if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): | ||||||
if_vars = {} # type: TypeMap | ||||||
else_vars = {} # type: TypeMap | ||||||
|
||||||
for i, expr in enumerate(node.operands): | ||||||
var_type = operand_types[i] | ||||||
other_type = operand_types[1 - i] | ||||||
|
||||||
if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): | ||||||
# This should only be true at most once: there should be | ||||||
# exactly two elements in node.operands and if the 'other type' is | ||||||
# a singleton type, it by definition does not need to be narrowed: | ||||||
# it already has the most precise type possible so does not need to | ||||||
# be narrowed/included in the output map. | ||||||
# | ||||||
# TODO: Generalize this to handle the case where 'other_type' is | ||||||
# a union of singleton types. | ||||||
operand_types = [] | ||||||
for expr in node.operands: | ||||||
if expr not in type_map: | ||||||
return {}, {} | ||||||
operand_types.append(coerce_to_literal(type_map[expr])) | ||||||
|
||||||
type_maps = [] | ||||||
for i, (operator, left_expr, right_expr) in enumerate(node.pairwise()): | ||||||
left_type = operand_types[i] | ||||||
right_type = operand_types[i + 1] | ||||||
|
||||||
if_map = {} # type: TypeMap | ||||||
else_map = {} # type: TypeMap | ||||||
if operator in {'in', 'not in'}: | ||||||
right_item_type = builtin_item_type(right_type) | ||||||
if right_item_type is None or is_optional(right_item_type): | ||||||
continue | ||||||
if (isinstance(right_item_type, Instance) | ||||||
and right_item_type.type.fullname() == 'builtins.object'): | ||||||
continue | ||||||
|
||||||
if (is_optional(left_type) and literal(left_expr) == LITERAL_TYPE | ||||||
and not is_literal_none(left_expr) and | ||||||
is_overlapping_erased_types(left_type, right_item_type)): | ||||||
if_map, else_map = {left_expr: remove_optional(left_type)}, {} | ||||||
else: | ||||||
continue | ||||||
elif operator in {'==', '!='}: | ||||||
if_map, else_map = self.narrow_given_equality( | ||||||
left_expr, left_type, right_expr, right_type, assume_identity=False) | ||||||
elif operator in {'is', 'is not'}: | ||||||
if_map, else_map = self.narrow_given_equality( | ||||||
left_expr, left_type, right_expr, right_type, assume_identity=True) | ||||||
else: | ||||||
continue | ||||||
|
||||||
if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): | ||||||
fallback_name = other_type.fallback.type.fullname() | ||||||
var_type = try_expanding_enum_to_union(var_type, fallback_name) | ||||||
if operator in {'not in', '!=', 'is not'}: | ||||||
if_map, else_map = else_map, if_map | ||||||
|
||||||
target_type = [TypeRange(other_type, is_upper_bound=False)] | ||||||
if_vars, else_vars = conditional_type_map(expr, var_type, target_type) | ||||||
break | ||||||
type_maps.append((if_map, else_map)) | ||||||
|
||||||
if is_not: | ||||||
if_vars, else_vars = else_vars, if_vars | ||||||
return if_vars, else_vars | ||||||
# Check for `x == y` where x is of type Optional[T] and y is of type T | ||||||
# or a type that overlaps with T (or vice versa). | ||||||
elif node.operators == ['==']: | ||||||
first_type = type_map[node.operands[0]] | ||||||
second_type = type_map[node.operands[1]] | ||||||
if is_optional(first_type) != is_optional(second_type): | ||||||
if is_optional(first_type): | ||||||
optional_type, comp_type = first_type, second_type | ||||||
optional_expr = node.operands[0] | ||||||
else: | ||||||
optional_type, comp_type = second_type, first_type | ||||||
optional_expr = node.operands[1] | ||||||
if is_overlapping_erased_types(optional_type, comp_type): | ||||||
return {optional_expr: remove_optional(optional_type)}, {} | ||||||
elif node.operators in [['in'], ['not in']]: | ||||||
expr = node.operands[0] | ||||||
left_type = type_map[expr] | ||||||
right_type = get_proper_type(builtin_item_type(type_map[node.operands[1]])) | ||||||
right_ok = right_type and (not is_optional(right_type) and | ||||||
(not isinstance(right_type, Instance) or | ||||||
right_type.type.fullname() != 'builtins.object')) | ||||||
if (right_type and right_ok and is_optional(left_type) and | ||||||
literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and | ||||||
is_overlapping_erased_types(left_type, right_type)): | ||||||
if node.operators == ['in']: | ||||||
return {expr: remove_optional(left_type)}, {} | ||||||
if node.operators == ['not in']: | ||||||
return {}, {expr: remove_optional(left_type)} | ||||||
if len(type_maps) == 0: | ||||||
return {}, {} | ||||||
elif len(type_maps) == 1: | ||||||
return type_maps[0] | ||||||
else: | ||||||
# Comparisons like 'a == b == c is d' is the same thing as | ||||||
# '(a == b) and (b == c) and (c is d)'. So after generating each | ||||||
# individual comparison's typemaps, we "and" them together here. | ||||||
# (Also see comments below where we handle the 'and' OpExpr.) | ||||||
final_if_map, final_else_map = type_maps[0] | ||||||
for if_map, else_map in type_maps[1:]: | ||||||
final_if_map = and_conditional_maps(final_if_map, if_map) | ||||||
final_else_map = or_conditional_maps(final_else_map, else_map) | ||||||
return final_if_map, final_else_map | ||||||
elif isinstance(node, RefExpr): | ||||||
# Restrict the type of the variable to True-ish/False-ish in the if and else branches | ||||||
# respectively | ||||||
|
@@ -3719,6 +3713,78 @@ def find_isinstance_check(self, node: Expression | |||||
# Not a supported isinstance check | ||||||
return {}, {} | ||||||
|
||||||
def narrow_given_equality(self, | ||||||
left_expr: Expression, | ||||||
left_type: Type, | ||||||
right_expr: Expression, | ||||||
right_type: Type, | ||||||
assume_identity: bool, | ||||||
) -> Tuple[TypeMap, TypeMap]: | ||||||
"""Assuming that the given 'left' and 'right' exprs are equal to each other, try | ||||||
producing TypeMaps refining the types of either the left or right exprs (or neither, | ||||||
if we can't learn anything from the comparison). | ||||||
|
||||||
For more details about what TypeMaps are, see the docstring in find_isinstance_check. | ||||||
|
||||||
If 'assume_identity' is true, assume that this comparison was done using an | ||||||
identity comparison (left_expr is right_expr), not just an equality comparison | ||||||
(left_expr == right_expr). Identity checks are not overridable, so we can infer | ||||||
more information in that case. | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also update docstring style here. |
||||||
|
||||||
# For the sake of simplicity, we currently attempt inferring a more precise type | ||||||
# for just one of the two variables. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How important is this? If we actually do this, I think it is better to use the type of right one to restrict the type of left one. |
||||||
comparisons = [ | ||||||
(left_expr, left_type, right_type), | ||||||
(right_expr, right_type, left_type), | ||||||
] | ||||||
|
||||||
for expr, expr_type, other_type in comparisons: | ||||||
# The 'expr' isn't an expression that we can refine the type of. Skip | ||||||
# attempting to refine this expr. | ||||||
if literal(expr) != LITERAL_TYPE: | ||||||
continue | ||||||
|
||||||
# Case 1: If the 'other_type' is a singleton (only one value has | ||||||
# the specified type), attempt to narrow 'expr_type' to just that | ||||||
# singleton type. | ||||||
if is_singleton_type(other_type): | ||||||
if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): | ||||||
if not assume_identity: | ||||||
# Our checks need to be more conservative if the operand is | ||||||
# '==' or '!=': all bets are off if either of the two operands | ||||||
# has a custom `__eq__` or `__ne__` method. | ||||||
# | ||||||
# So, we permit this check to succeed only if 'other_type' does | ||||||
# not define custom equality logic | ||||||
if not uses_default_equality_checks(expr_type): | ||||||
continue | ||||||
if not uses_default_equality_checks(other_type.fallback): | ||||||
continue | ||||||
fallback_name = other_type.fallback.type.fullname() | ||||||
expr_type = try_expanding_enum_to_union(expr_type, fallback_name) | ||||||
|
||||||
target_type = [TypeRange(other_type, is_upper_bound=False)] | ||||||
return conditional_type_map(expr, expr_type, target_type) | ||||||
|
||||||
# Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'. | ||||||
# | ||||||
# Note: This check is actually strictly speaking unsafe: stripping away the 'None' | ||||||
# would be unsound in the case where A defines an '__eq__' method that always | ||||||
# returns 'True', for example. | ||||||
# | ||||||
# We implement this check partly for backwards-compatibility reasons and partly | ||||||
# because those kinds of degenerate '__eq__' implementations are probably rare | ||||||
# enough that this is fine in practice. | ||||||
# | ||||||
# We could also probably generalize this block to strip away *any* singleton type, | ||||||
# if we were fine with a bit more unsoundness. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prepend the last sentence with |
||||||
if is_optional(expr_type) and not is_optional(other_type): | ||||||
if is_overlapping_erased_types(expr_type, other_type): | ||||||
return {expr: remove_optional(expr_type)}, {} | ||||||
|
||||||
return {}, {} | ||||||
|
||||||
# | ||||||
# Helpers | ||||||
# | ||||||
|
@@ -4615,6 +4681,32 @@ def is_private(node_name: str) -> bool: | |||||
return node_name.startswith('__') and not node_name.endswith('__') | ||||||
|
||||||
|
||||||
def uses_default_equality_checks(typ: Type) -> bool: | ||||||
"""Returns 'true' if we know for certain that the given type is using | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
the default __eq__ and __ne__ checks defined in 'builtins.object'. | ||||||
We can use this information to make more aggressive inferences when | ||||||
analyzing things like equality checks. | ||||||
|
||||||
When in doubt, this function will conservatively bias towards | ||||||
returning False. | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use this style*
* This review comment was sponsored by Jukka |
||||||
if isinstance(typ, UnionType): | ||||||
return all(map(uses_default_equality_checks, typ.items)) | ||||||
# TODO: Generalize this so it'll handle other types with fallbacks | ||||||
if isinstance(typ, LiteralType): | ||||||
typ = typ.fallback | ||||||
if isinstance(typ, Instance): | ||||||
typeinfo = typ.type | ||||||
eq_sym = typeinfo.get('__eq__') | ||||||
ne_sym = typeinfo.get('__ne__') | ||||||
if eq_sym is None or ne_sym is None: | ||||||
return False | ||||||
return (eq_sym.fullname == 'builtins.object.__eq__' | ||||||
and ne_sym.fullname == 'builtins.object.__ne__') | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks too strong to me. Wouldn't |
||||||
else: | ||||||
return False | ||||||
|
||||||
|
||||||
def is_singleton_type(typ: Type) -> bool: | ||||||
"""Returns 'true' if this type is a "singleton type" -- if there exists | ||||||
exactly only one runtime value associated with this type. | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -611,7 +611,7 @@ main:2: note: Revealed type is 'builtins.int' | |
[out2] | ||
main:2: note: Revealed type is 'builtins.str' | ||
|
||
[case testEnumReachabilityChecksBasic] | ||
[case testEnumReachabilityChecksBasicIdentity] | ||
from enum import Enum | ||
from typing_extensions import Literal | ||
|
||
|
@@ -659,6 +659,54 @@ else: | |
reveal_type(y) # No output here: this branch is unreachable | ||
[builtins fixtures/bool.pyi] | ||
|
||
[case testEnumReachabilityChecksBasicEquality] | ||
from enum import Enum | ||
from typing_extensions import Literal | ||
|
||
class Foo(Enum): | ||
A = 1 | ||
B = 2 | ||
C = 3 | ||
|
||
x: Literal[Foo.A, Foo.B, Foo.C] | ||
if x == Foo.A: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
elif x == Foo.B: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' | ||
elif x == Foo.C: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' | ||
else: | ||
reveal_type(x) # No output here: this branch is unreachable | ||
|
||
if Foo.A == x: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
elif Foo.B == x: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' | ||
elif Foo.C == x: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' | ||
else: | ||
reveal_type(x) # No output here: this branch is unreachable | ||
|
||
y: Foo | ||
if y == Foo.A: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
elif y == Foo.B: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' | ||
elif y == Foo.C: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' | ||
else: | ||
reveal_type(y) # No output here: this branch is unreachable | ||
|
||
if Foo.A == y: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
elif Foo.B == y: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' | ||
elif Foo.C == y: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' | ||
else: | ||
reveal_type(y) # No output here: this branch is unreachable | ||
[builtins fixtures/bool.pyi] | ||
|
||
[case testEnumReachabilityChecksIndirect] | ||
from enum import Enum | ||
from typing_extensions import Literal, Final | ||
|
@@ -854,3 +902,81 @@ def process(response: Union[str, Reason] = '') -> str: | |
return 'PROCESSED: ' + response | ||
|
||
[builtins fixtures/primitives.pyi] | ||
|
||
[case testEnumReachabilityDisabledGivenCustomEquality] | ||
from typing import Union | ||
from enum import Enum | ||
|
||
class Parent(Enum): | ||
def __ne__(self, other: object) -> bool: return True | ||
|
||
class Foo(Enum): | ||
A = 1 | ||
B = 2 | ||
def __eq__(self, other: object) -> bool: return True | ||
|
||
class Bar(Parent): | ||
A = 1 | ||
B = 2 | ||
|
||
class Ok(Enum): | ||
A = 1 | ||
B = 2 | ||
|
||
x: Foo | ||
if x is Foo.A: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
else: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' | ||
|
||
if x == Foo.A: | ||
reveal_type(x) # N: Revealed type is '__main__.Foo' | ||
else: | ||
reveal_type(x) # N: Revealed type is '__main__.Foo' | ||
|
||
y: Bar | ||
if y is Bar.A: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.A]' | ||
else: | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Bar.B]' | ||
|
||
if y == Bar.A: | ||
reveal_type(y) # N: Revealed type is '__main__.Bar' | ||
else: | ||
reveal_type(y) # N: Revealed type is '__main__.Bar' | ||
|
||
z1: Union[Bar, Ok] | ||
if z1 is Ok.A: | ||
reveal_type(z1) # N: Revealed type is 'Literal[__main__.Ok.A]' | ||
else: | ||
reveal_type(z1) # N: Revealed type is 'Union[__main__.Bar, Literal[__main__.Ok.B]]' | ||
|
||
z2: Union[Bar, Ok] | ||
if z2 == Ok.A: | ||
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' | ||
else: | ||
reveal_type(z2) # N: Revealed type is 'Union[__main__.Bar, __main__.Ok]' | ||
[builtins fixtures/primitives.pyi] | ||
|
||
[case testEnumReachabilityWithChaining] | ||
from enum import Enum | ||
class Foo(Enum): | ||
A = 1 | ||
B = 2 | ||
|
||
x: Foo | ||
y: Foo | ||
if x is Foo.A is y: | ||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
else: | ||
reveal_type(x) # N: Revealed type is '__main__.Foo' | ||
reveal_type(y) # N: Revealed type is '__main__.Foo' | ||
|
||
if x == Foo.A == y: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does x == y == Foo.A work? (It feels like we ought to support this or I'm not sure how much this helps?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this is an important case, and we should have a test for this. |
||
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' | ||
else: | ||
reveal_type(x) # N: Revealed type is '__main__.Foo' | ||
reveal_type(y) # N: Revealed type is '__main__.Foo' | ||
[builtins fixtures/primitives.pyi] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add more tests for chaining. The PR title mentions chaining, but there is literally only one test about this.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this TODO moved somewhere or is it now fixed? Do we have regression test in the latter case?