Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for narrowing Literals using equality #8151

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 110 additions & 47 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import (
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
Sequence, Mapping, Generic, AbstractSet
Sequence, Mapping, Generic, AbstractSet, Callable
)
from typing_extensions import Final

Expand Down Expand Up @@ -50,7 +50,8 @@
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
true_only, false_only, function_type, TypeVarExtractor,
true_only, false_only, function_type, TypeVarExtractor, custom_special_method,
is_literal_type_like,
)
from mypy import message_registry
from mypy.subtypes import (
Expand Down Expand Up @@ -3890,20 +3891,59 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM

partial_type_maps = []
for operator, expr_indices in simplified_operator_list:
if operator in {'is', 'is not'}:
if_map, else_map = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
)
elif operator in {'==', '!='}:
if_map, else_map = self.refine_equality_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
)
if operator in {'is', 'is not', '==', '!='}:
# is_valid_target:
# Controls which types we're allowed to narrow exprs to. Note that
# we cannot use 'is_literal_type_like' in both cases since doing
# 'x = 10000 + 1; x is 10001' is not always True in all Python impls.
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
#
# coerce_only_in_literal_context:
# If true, coerce types into literal types only if one or more of
# the provided exprs contains an explicit Literal type. This could
# technically be set to any arbitrary value, but it seems being liberal
# with narrowing when using 'is' and conservative when using '==' seems
# to break the least amount of real-world code.
#
# should_narrow_by_identity:
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
# that could cause identity-based narrowing to produce invalid results.
if operator in {'is', 'is not'}:
is_valid_target = is_singleton_type # type: Callable[[Type], bool]
coerce_only_in_literal_context = False
should_narrow_by_identity = True
else:
is_valid_target = is_exactly_literal_type
coerce_only_in_literal_context = True

def has_no_custom_eq_checks(t: Type) -> bool:
return not custom_special_method(t, '__eq__', check_all=False) \
and not custom_special_method(t, '__ne__', check_all=False)
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))

if_map = {} # type: TypeMap
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved
else_map = {} # type: TypeMap
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)

# Strictly speaking, we should also skip this check if the objects in the expr
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
# assume nobody would actually create a custom objects that considers itself
# equal to None.
if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
)
elif operator in {'in', 'not in'}:
assert len(expr_indices) == 2
left_index, right_index = expr_indices
Expand Down Expand Up @@ -4146,8 +4186,10 @@ def refine_identity_comparison_expression(self,
operand_types: List[Type],
chain_indices: List[int],
narrowable_operand_indices: AbstractSet[int],
is_valid_target: Callable[[ProperType], bool],
coerce_only_in_literal_context: bool,
) -> Tuple[TypeMap, TypeMap]:
"""Produces conditional type maps refining expressions used in an identity comparison.
"""Produces conditional type maps refining exprs used in an identity/equality comparison.
Michael0x2a marked this conversation as resolved.
Show resolved Hide resolved

The 'operands' and 'operand_types' lists should be the full list of operands used
in the overall comparison expression. The 'chain_indices' list is the list of indices
Expand All @@ -4163,30 +4205,45 @@ def refine_identity_comparison_expression(self,
The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
to refine the types of: that is, all operands that will potentially be a part of
the output TypeMaps.

Although this function could theoretically try setting the types of the operands
in the chains to the meet, doing that causes too many issues in real-world code.
Instead, we use 'is_valid_target' to identify which of the given chain types
we could plausibly use as the refined type for the expressions in the chain.

Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
expressions in the chain to a Literal type. Performing this coercion is sometimes
too aggressive of a narrowing, depending on context.
"""
singleton = None # type: Optional[ProperType]
possible_singleton_indices = []
should_coerce = True
if coerce_only_in_literal_context:
should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices)

target = None # type: Optional[Type]
possible_target_indices = []
for i in chain_indices:
coerced_type = coerce_to_literal(operand_types[i])
if not is_singleton_type(coerced_type):
expr_type = operand_types[i]
if should_coerce:
expr_type = coerce_to_literal(expr_type)
if not is_valid_target(get_proper_type(expr_type)):
continue
if singleton and not is_same_type(singleton, coerced_type):
# We have multiple disjoint singleton types. So the 'if' branch
if target and not is_same_type(target, expr_type):
# We have multiple disjoint target types. So the 'if' branch
# must be unreachable.
return None, {}
singleton = coerced_type
possible_singleton_indices.append(i)
target = expr_type
possible_target_indices.append(i)

# There's nothing we can currently infer if none of the operands are singleton types,
# There's nothing we can currently infer if none of the operands are valid targets,
# so we end early and infer nothing.
if singleton is None:
if target is None:
return {}, {}

# If possible, use an unassignable expression as the singleton.
# We skip refining the type of the singleton below, so ideally we'd
# If possible, use an unassignable expression as the target.
# We skip refining the type of the target below, so ideally we'd
# want to pick an expression we were going to skip anyways.
singleton_index = -1
for i in possible_singleton_indices:
for i in possible_target_indices:
if i not in narrowable_operand_indices:
singleton_index = i

Expand Down Expand Up @@ -4215,20 +4272,21 @@ def refine_identity_comparison_expression(self,
# currently will just mark the whole branch as unreachable if either operand is
# narrowed to <uninhabited>.
if singleton_index == -1:
singleton_index = possible_singleton_indices[-1]
singleton_index = possible_target_indices[-1]

enum_name = None
if isinstance(singleton, LiteralType) and singleton.is_enum_literal():
enum_name = singleton.fallback.type.fullname
target = get_proper_type(target)
if isinstance(target, LiteralType) and target.is_enum_literal():
enum_name = target.fallback.type.fullname

target_type = [TypeRange(singleton, is_upper_bound=False)]
target_type = [TypeRange(target, is_upper_bound=False)]

partial_type_maps = []
for i in chain_indices:
# If we try refining a singleton against itself, conditional_type_map
# If we try refining a type against itself, conditional_type_map
# will end up assuming that the 'else' branch is unreachable. This is
# typically not what we want: generally the user will intend for the
# singleton type to be some fixed 'sentinel' value and will want to refine
# target type to be some fixed 'sentinel' value and will want to refine
# the other exprs against this one instead.
if i == singleton_index:
continue
Expand All @@ -4246,17 +4304,16 @@ def refine_identity_comparison_expression(self,

return reduce_partial_conditional_maps(partial_type_maps)

def refine_equality_comparison_expression(self,
operands: List[Expression],
operand_types: List[Type],
chain_indices: List[int],
narrowable_operand_indices: AbstractSet[int],
) -> Tuple[TypeMap, TypeMap]:
"""Produces conditional type maps refining expressions used in an equality comparison.
def refine_away_none_in_comparison(self,
operands: List[Expression],
operand_types: List[Type],
chain_indices: List[int],
narrowable_operand_indices: AbstractSet[int],
) -> Tuple[TypeMap, TypeMap]:
"""Produces conditional type maps refining away None in an identity/equality chain.

For more details, see the docstring of 'refine_equality_comparison' up above.
The only difference is that this function is for refining equality operations
(e.g. 'a == b == c') instead of identity ('a is b is c').
For more details about what the different arguments mean, see the
docstring of 'refine_identity_comparison_expression' up above.
"""
non_optional_types = []
for i in chain_indices:
Expand Down Expand Up @@ -4749,7 +4806,7 @@ class Foo(Enum):
return False

parent_type = get_proper_type(parent_type)
member_type = coerce_to_literal(member_type)
member_type = get_proper_type(coerce_to_literal(member_type))
if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType):
return False

Expand Down Expand Up @@ -5540,3 +5597,9 @@ def has_bool_item(typ: ProperType) -> bool:
return any(is_named_instance(item, 'builtins.bool')
for item in typ.items)
return False


# TODO: why can't we define this as an inline function?
# Does mypyc not support them?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean lambda by "inline function"? What exactly goes wrong when you try?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I actually don't remember. I'll try experimenting with this a little later today and see if I can repro.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, yeah, I don't seem to be able to repro this anymore? Maybe it was fixed after I rebased or maybe I just doing something wrong before, but defining an inline function seems to be working fine now.

Defining a lambda unfortunately causes flake8 to complain -- it doesn't like it when you assign a lambda to a variable.

Anyways, I moved this function back to where it's being used.

def is_exactly_literal_type(t: Type) -> bool:
return isinstance(get_proper_type(t), LiteralType)
47 changes: 3 additions & 44 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
SYMBOL_FUNCBASE_TYPES
)
from mypy.literals import literal
from mypy import nodes
Expand All @@ -51,15 +50,16 @@
from mypy import erasetype
from mypy.checkmember import analyze_member_access, type_object_type
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
from mypy.checkstrformat import StringFormatterChecker, custom_special_method
from mypy.checkstrformat import StringFormatterChecker
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
from mypy.typeops import (
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
function_type, callable_type, try_getting_str_literals
function_type, callable_type, try_getting_str_literals, custom_special_method,
is_literal_type_like,
)
import mypy.errorcodes as codes

Expand Down Expand Up @@ -4266,24 +4266,6 @@ def merge_typevars_in_callables_by_name(
return output, variables


def is_literal_type_like(t: Optional[Type]) -> bool:
"""Returns 'true' if the given type context is potentially either a LiteralType,
a Union of LiteralType, or something similar.
"""
t = get_proper_type(t)
if t is None:
return False
elif isinstance(t, LiteralType):
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
elif isinstance(t, TypeVarType):
return (is_literal_type_like(t.upper_bound)
or any(is_literal_type_like(item) for item in t.values))
else:
return False


def try_getting_literal(typ: Type) -> ProperType:
"""If possible, get a more precise literal type for a given type."""
typ = get_proper_type(typ)
Expand All @@ -4305,29 +4287,6 @@ def is_expr_literal_type(node: Expression) -> bool:
return False


def custom_equality_method(typ: Type) -> bool:
"""Does this type have a custom __eq__() method?"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
method = typ.type.get('__eq__')
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
return not method.node.info.fullname.startswith('builtins.')
return False
if isinstance(typ, UnionType):
return any(custom_equality_method(t) for t in typ.items)
if isinstance(typ, TupleType):
return custom_equality_method(tuple_fallback(typ))
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __eq__ on the metaclass for class objects.
return custom_equality_method(typ.fallback)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False


def has_bytes_component(typ: Type, py2: bool = False) -> bool:
"""Is this one of builtin byte types, or a union that contains it?"""
typ = get_proper_type(typ)
Expand Down
35 changes: 3 additions & 32 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from mypy.types import (
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
CallableType, LiteralType, get_proper_types
LiteralType, get_proper_types
)
from mypy.nodes import (
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2,
SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
)
import mypy.errorcodes as codes

Expand All @@ -35,7 +35,7 @@
from mypy import message_registry
from mypy.messages import MessageBuilder
from mypy.maptype import map_instance_to_supertype
from mypy.typeops import tuple_fallback
from mypy.typeops import custom_special_method
from mypy.subtypes import is_subtype
from mypy.parse import parse

Expand Down Expand Up @@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool:
elif isinstance(typ, UnionType):
return any(has_type_component(t, fullname) for t in typ.relevant_items())
return False


def custom_special_method(typ: Type, name: str,
check_all: bool = False) -> bool:
"""Does this type have a custom special method such as __format__() or __eq__()?

If check_all is True ensure all items of a union have a custom method, not just some.
"""
typ = get_proper_type(typ)
if isinstance(typ, Instance):
method = typ.type.get(name)
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
if method.node.info:
return not method.node.info.fullname.startswith('builtins.')
return False
if isinstance(typ, UnionType):
if check_all:
return all(custom_special_method(t, name, check_all) for t in typ.items)
return any(custom_special_method(t, name) for t in typ.items)
if isinstance(typ, TupleType):
return custom_special_method(tuple_fallback(typ), name)
if isinstance(typ, CallableType) and typ.is_type_obj():
# Look up __method__ on the metaclass for class objects.
return custom_special_method(typ.fallback, name)
if isinstance(typ, AnyType):
# Avoid false positives in uncertain cases.
return True
# TODO: support other types (see ExpressionChecker.has_member())?
return False
Loading