Skip to content

Commit

Permalink
Refine parent type when narrowing "lookup" expressions
Browse files Browse the repository at this point in the history
This diff adds support for the following pattern:

```python
from typing import Enum, List
from enum import Enum

class Key(Enum):
    A = 1
    B = 2

class Foo:
    key: Literal[Key.A]
    blah: List[int]

class Bar:
    key: Literal[Key.B]
    something: List[str]

x: Union[Foo, Bar]
if x.key is Key.A:
    reveal_type(x)  # Revealed type is 'Foo'
else:
    reveal_type(x)  # Revealed type is 'Bar'
```

In short, when we do `x.key is Key.A`, we "propagate" the information
we discovered about `x.key` up one level to refine the type of `x`.

We perform this propagation only when `x` is a Union and only when we
are doing member or index lookups into instances, typeddicts,
namedtuples, and tuples. For indexing operations, we have one additional
limitation: we *must* use a literal expression in order for narrowing
to work at all. Using Literal types or Final instances won't work;
See python#7905 for more details.

To put it another way, this adds support for tagged unions, I guess.

This more or less resolves python#7344.
We currently don't have support for narrowing based on string or int
literals, but that's a separate issue and should be resolved by
python#7169 (which I resumed work
on earlier this week).
  • Loading branch information
Michael0x2a committed Nov 9, 2019
1 parent 84126ab commit 14fa27f
Show file tree
Hide file tree
Showing 6 changed files with 562 additions and 22 deletions.
184 changes: 173 additions & 11 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import (
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
Sequence
Mapping, Sequence
)
from typing_extensions import Final

Expand Down Expand Up @@ -43,11 +43,13 @@
)
import mypy.checkexpr
from mypy.checkmember import (
analyze_descriptor_access, type_object_type,
analyze_member_access, analyze_descriptor_access, type_object_type,
)
from mypy.typeops import (
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
erase_def_to_union_or_bound, erase_to_union_or_bound,
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
true_only, false_only, function_type,
)
from mypy import message_registry
Expand All @@ -72,9 +74,6 @@
from mypy.plugin import Plugin, CheckerPluginInterface
from mypy.sharedparse import BINARY_MAGIC_METHODS
from mypy.scope import Scope
from mypy.typeops import (
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union
)
from mypy import state, errorcodes as codes
from mypy.traverser import has_return_statement, all_return_statements
from mypy.errorcodes import ErrorCode
Expand Down Expand Up @@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression
Guaranteed to not return None, None. (But may return {}, {})
"""
if_map, else_map = self.find_isinstance_check_helper(node)
new_if_map = self.propagate_up_typemap_info(self.type_map, if_map)
new_else_map = self.propagate_up_typemap_info(self.type_map, else_map)
return new_if_map, new_else_map

def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]:
type_map = self.type_map
if is_true_literal(node):
return {}, None
Expand Down Expand Up @@ -3835,28 +3840,185 @@ def find_isinstance_check(self, node: Expression
else None)
return if_map, else_map
elif isinstance(node, OpExpr) and node.op == 'and':
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)

# (e1 and e2) is true if both e1 and e2 are true,
# and false if at least one of e1 and e2 is false.
return (and_conditional_maps(left_if_vars, right_if_vars),
or_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, OpExpr) and node.op == 'or':
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)

# (e1 or e2) is true if at least one of e1 or e2 is true,
# and false if both e1 and e2 are false.
return (or_conditional_maps(left_if_vars, right_if_vars),
and_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, UnaryExpr) and node.op == 'not':
left, right = self.find_isinstance_check(node.expr)
left, right = self.find_isinstance_check_helper(node.expr)
return right, left

# Not a supported isinstance check
return {}, {}

def propagate_up_typemap_info(self,
existing_types: Mapping[Expression, Type],
new_types: TypeMap) -> TypeMap:
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
Specifically, this function accepts two mappings of expression to original types:
the original mapping (existing_types), and a new mapping (new_types) intended to
update the original.
This function iterates through new_types and attempts to use the information to try
refining any parent types that happen to be unions.
For example, suppose there are two types "A = Tuple[int, int]" and "B = Tuple[str, str]".
Next, suppose that 'new_types' specifies the expression 'foo[0]' has a refined type
of 'int' and that 'foo' was previously deduced to be of type Union[A, B].
Then, this function will observe that since A[0] is an int and B[0] is not, the type of
'foo' can be further refined from Union[A, B] into just B.
We perform this kind of "parent narrowing" for member lookup expressions and indexing
expressions into tuples, namedtuples, and typeddicts. This narrowing is also performed
only once, for the immediate parents of any "lookup" expressions in `new_types`.
We return the newly refined map. This map is guaranteed to be a superset of 'new_types'.
"""
if new_types is None:
return None
output_map = {}
for expr, expr_type in new_types.items():
# The original inferred type should always be present in the output map, of course
output_map[expr] = expr_type

# Next, try using this information to refine the parent type, if applicable.
# Note that we currently refine just the immediate parent.
#
# TODO: Should we also try recursively refining any parents of the parents?
#
# One quick-and-dirty way of doing this would be to have the caller repeatedly run
# this function until we reach fixpoint; another way would be to modify
# 'refine_parent_type' to run in a loop. Both approaches seem expensive though.
new_mapping = self.refine_parent_type(existing_types, expr, expr_type)
for parent_expr, proposed_parent_type in new_mapping.items():
# We don't try inferring anything if we've already inferred something for
# the parent expression.
# TODO: Consider picking the narrower type instead of always discarding this?
if parent_expr in new_types:
continue
output_map[parent_expr] = proposed_parent_type
return output_map

def refine_parent_type(self,
existing_types: Mapping[Expression, Type],
expr: Expression,
expr_type: Type) -> Mapping[Expression, Type]:
"""Checks if the given expr is a 'lookup operation' into a union and refines the parent type
based on the 'expr_type'.
For more details about what a 'lookup operation' is and how we use the expr_type to refine
the parent type, see the docstring in 'propagate_up_typemap_info'.
"""

# First, check if this expression is one that's attempting to
# "lookup" some key in the parent type. If so, save the parent type
# and create function that will try replaying the same lookup
# operation against arbitrary types.
if isinstance(expr, MemberExpr):
parent_expr = expr.expr
parent_type = existing_types.get(parent_expr)
member_name = expr.name

def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
msg_copy = self.msg.clean_copy()
msg_copy.disable_count = 0
member_type = analyze_member_access(
name=member_name,
typ=new_parent_type,
context=parent_expr,
is_lvalue=False,
is_super=False,
is_operator=False,
msg=msg_copy,
original_type=new_parent_type,
chk=self,
in_literal_context=False,
)
if msg_copy.is_errors():
return None
else:
return member_type
elif isinstance(expr, IndexExpr):
parent_expr = expr.base
parent_type = existing_types.get(parent_expr)

index_type = existing_types.get(expr.index)
if index_type is None:
return {}

str_literals = try_getting_str_literals_from_type(index_type)
if str_literals is not None:
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
if not isinstance(new_parent_type, TypedDictType):
return None
try:
assert str_literals is not None
member_types = [new_parent_type.items[key] for key in str_literals]
except KeyError:
return None
return make_simplified_union(member_types)
else:
int_literals = try_getting_int_literals_from_type(index_type)
if int_literals is not None:
def replay_lookup(new_parent_type: ProperType) -> Optional[Type]:
if not isinstance(new_parent_type, TupleType):
return None
try:
assert int_literals is not None
member_types = [new_parent_type.items[key] for key in int_literals]
except IndexError:
return None
return make_simplified_union(member_types)
else:
return {}
else:
return {}

# If we somehow didn't previously derive the parent type, abort:
# something went wrong at an earlier stage.
if parent_type is None:
return {}

# We currently only try refining the parent type if it's a Union.
parent_type = get_proper_type(parent_type)
if not isinstance(parent_type, UnionType):
return {}

# Take each element in the parent union and replay the original lookup procedure
# to figure out which parents are compatible.
new_parent_types = []
for item in parent_type.items:
item = get_proper_type(item)
member_type = replay_lookup(item)
if member_type is None:
# We were unable to obtain the member type. So, we give up on refining this
# parent type entirely.
return {}

if is_overlapping_types(member_type, expr_type):
new_parent_types.append(item)

# If none of the parent types overlap (if we derived an empty union), something
# went wrong. We should never hit this case, but deriving the uninhabited type or
# reporting an error both seem unhelpful. So we abort.
if not new_parent_types:
return {}

return {parent_expr: make_simplified_union(new_parent_types)}

#
# Helpers
#
Expand Down
3 changes: 3 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2704,6 +2704,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
index = e.index
left_type = get_proper_type(left_type)

# Visit the index, just to make sure we have a type for it available
self.accept(index)

if isinstance(left_type, UnionType):
original_type = original_type or left_type
return make_simplified_union([self.visit_index_with_type(typ, e,
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'check-isinstance.test',
'check-lists.test',
'check-namedtuple.test',
'check-narrowing.test',
'check-typeddict.test',
'check-type-aliases.test',
'check-ignore.test',
Expand Down
78 changes: 68 additions & 10 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set
from typing import cast, Optional, List, Sequence, Set, TypeVar, Type as TypingType
import sys

from mypy.types import (
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType,
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
copy_type, TypeAliasType
)
from mypy.nodes import (
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, ARG_POS,
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
Expression, StrExpr, Var
)
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance:
return Instance(info, [join_type_list(typ.items)])


def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
"""Returns the Instance fallback for this type if one exists.
Otherwise, returns None.
"""
if isinstance(typ, Instance):
return typ
elif isinstance(typ, TupleType):
return tuple_fallback(typ)
elif isinstance(typ, TypedDictType):
return typ.fallback
elif isinstance(typ, FunctionLike):
return typ.fallback
elif isinstance(typ, LiteralType):
return typ.fallback
else:
return None


def type_object_type_from_function(signature: FunctionLike,
info: TypeInfo,
def_info: TypeInfo,
Expand Down Expand Up @@ -475,27 +494,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
2. 'typ' is a LiteralType containing a string
3. 'typ' is a UnionType containing only LiteralType of strings
"""
typ = get_proper_type(typ)

if isinstance(expr, StrExpr):
return [expr.value]

# TODO: See if we can eliminate this function and call the below one directly
return try_getting_str_literals_from_type(typ)


def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
"""If the given expression or type corresponds to a string Literal
or a union of string Literals, returns a list of the underlying strings.
Otherwise, returns None.
For example, if we had the type 'Literal["foo", "bar"]' as input, this function
would return a list of strings ["foo", "bar"].
"""
return try_getting_literals_from_type(typ, str, "builtins.str")


def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
"""If the given expression or type corresponds to an int Literal
or a union of int Literals, returns a list of the underlying ints.
Otherwise, returns None.
For example, if we had the type 'Literal[1, 2, 3]' as input, this function
would return a list of ints [1, 2, 3].
"""
return try_getting_literals_from_type(typ, int, "builtins.int")


T = TypeVar('T')


def try_getting_literals_from_type(typ: Type,
target_literal_type: TypingType[T],
target_fullname: str) -> Optional[List[T]]:
"""If the given expression or type corresponds to a Literal or
union of Literals where the underlying values corresponds to the given
target type, returns a list of those underlying values. Otherwise,
returns None.
"""
typ = get_proper_type(typ)

if isinstance(typ, Instance) and typ.last_known_value is not None:
possible_literals = [typ.last_known_value] # type: List[Type]
elif isinstance(typ, UnionType):
possible_literals = list(typ.items)
else:
possible_literals = [typ]

strings = []
literals = [] # type: List[T]
for lit in get_proper_types(possible_literals):
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname:
val = lit.value
assert isinstance(val, str)
strings.append(val)
if isinstance(val, target_literal_type):
literals.append(val)
else:
return None
else:
return None
return strings
return literals


def get_enum_values(typ: Instance) -> List[str]:
Expand Down
Loading

0 comments on commit 14fa27f

Please sign in to comment.