Skip to content

Commit

Permalink
Refine parent type when narrowing "lookup" expressions
Browse files Browse the repository at this point in the history
This diff adds support for the following pattern:

```python
from typing import Enum, List
from enum import Enum

class Key(Enum):
    A = 1
    B = 2

class Foo:
    key: Literal[Key.A]
    blah: List[int]

class Bar:
    key: Literal[Key.B]
    something: List[str]

x: Union[Foo, Bar]
if x.key is Key.A:
    reveal_type(x)  # Revealed type is 'Foo'
else:
    reveal_type(x)  # Revealed type is 'Bar'
```

In short, when we do `x.key is Key.A`, we "propagate" the information
we discovered about `x.key` up one level to refine the type of `x`.

We perform this propagation only when `x` is a Union and only when we
are doing member or index lookups into instances, typeddicts,
namedtuples, and tuples. For indexing operations, we have one additional
limitation: we *must* use a literal expression in order for narrowing
to work at all. Using Literal types or Final instances won't work;
See python#7905 for more details.

To put it another way, this adds support for tagged unions, I guess.

This more or less resolves python#7344.
We currently don't have support for narrowing based on string or int
literals, but that's a separate issue and should be resolved by
python#7169 (which I resumed work
on earlier this week).
  • Loading branch information
Michael0x2a committed Nov 9, 2019
1 parent 84126ab commit 0526133
Show file tree
Hide file tree
Showing 6 changed files with 474 additions and 20 deletions.
118 changes: 108 additions & 10 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import (
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
Sequence
Mapping, Sequence
)
from typing_extensions import Final

Expand Down Expand Up @@ -47,7 +47,9 @@
)
from mypy.typeops import (
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
erase_def_to_union_or_bound, erase_to_union_or_bound,
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
tuple_fallback, lookup_attribute_type, is_singleton_type, try_expanding_enum_to_union,
true_only, false_only, function_type,
)
from mypy import message_registry
Expand All @@ -72,9 +74,6 @@
from mypy.plugin import Plugin, CheckerPluginInterface
from mypy.sharedparse import BINARY_MAGIC_METHODS
from mypy.scope import Scope
from mypy.typeops import (
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union
)
from mypy import state, errorcodes as codes
from mypy.traverser import has_return_statement, all_return_statements
from mypy.errorcodes import ErrorCode
Expand Down Expand Up @@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression
Guaranteed to not return None, None. (But may return {}, {})
"""
if_map, else_map = self.find_isinstance_check_helper(node)
new_if_map = propagate_up_typemap_info(self.type_map, if_map)
new_else_map = propagate_up_typemap_info(self.type_map, else_map)
return new_if_map, new_else_map

def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]:
type_map = self.type_map
if is_true_literal(node):
return {}, None
Expand Down Expand Up @@ -3835,23 +3840,23 @@ def find_isinstance_check(self, node: Expression
else None)
return if_map, else_map
elif isinstance(node, OpExpr) and node.op == 'and':
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)

# (e1 and e2) is true if both e1 and e2 are true,
# and false if at least one of e1 and e2 is false.
return (and_conditional_maps(left_if_vars, right_if_vars),
or_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, OpExpr) and node.op == 'or':
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)

# (e1 or e2) is true if at least one of e1 or e2 is true,
# and false if both e1 and e2 are false.
return (or_conditional_maps(left_if_vars, right_if_vars),
and_conditional_maps(left_else_vars, right_else_vars))
elif isinstance(node, UnaryExpr) and node.op == 'not':
left, right = self.find_isinstance_check(node.expr)
left, right = self.find_isinstance_check_helper(node.expr)
return right, left

# Not a supported isinstance check
Expand Down Expand Up @@ -4780,3 +4785,96 @@ def has_bool_item(typ: ProperType) -> bool:
return any(is_named_instance(item, 'builtins.bool')
for item in typ.items)
return False


def propagate_up_typemap_info(existing_types: Mapping[Expression, Type],
new_types: TypeMap) -> TypeMap:
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
Specifically, this function accepts two mappings of expression to original types:
the original mapping (existing_types), and a new mapping (new_types) intended to
update the original.
This function iterates through new_types and attempts to use the information to try
refining the parent type if
"""
if new_types is None:
return None
output_map = {}
for expr, typ in new_types.items():
# The original inferred type should always be present in the output map, of course
output_map[expr] = typ

# Next, check and see if this expression is one that's attempting to
# "index" into the parent type. If so, grab both the parent and the "key".
keys = [] # type: Sequence[Union[str, int]]
if isinstance(expr, MemberExpr):
parent_expr = expr.expr
parent_type = existing_types.get(parent_expr)
variant_name = expr.name
keys = [variant_name]
elif isinstance(expr, IndexExpr):
parent_expr = expr.base
parent_type = existing_types.get(parent_expr)

variant_type = existing_types.get(expr.index)
if variant_type is None:
continue

str_literals = try_getting_str_literals_from_type(variant_type)
if str_literals is not None:
keys = str_literals
else:
int_literals = try_getting_int_literals_from_type(variant_type)
if int_literals is not None:
keys = int_literals
else:
continue
else:
continue

# We don't try inferring anything if we've either already inferred something for
# the parent expression or if the parent somehow doesn't already have an existing type
if parent_expr in new_types or parent_type is None:
continue

# If the parent isn't a union, we won't be able to perform any useful refinements.
# So, give up and carry on.
#
# TODO: We currently refine just the immediate parent. Should we also try refining
# any parents of the parents?
#
# One quick-and-dirty way of doing this would be to have the caller repeatedly run
# this function until we seem fixpoint, but that seems expensive.
parent_type = get_proper_type(parent_type)
if not isinstance(parent_type, UnionType):
continue

# Take each potential parent type in the union and try "indexing" into it using.
# Does the resulting type overlap with the deduced type of the original expression?
# If so, keep the parent type in the union.
new_parent_types = []
for item in parent_type.items:
item = get_proper_type(item)
member_types = []
for key in keys:
t = lookup_attribute_type(item, key)
if t is not None:
member_types.append(t)
member_type_for_item = make_simplified_union(member_types)
if member_type_for_item is None:
# We were unable to obtain the member type. So, we give up on refining this
# parent type entirely.
new_parent_types = []
break

if is_overlapping_types(member_type_for_item, typ):
new_parent_types.append(item)

# If none of the parent types overlap (if we derived an empty union), either
# we deliberately aborted or something went wrong. Deriving the uninhabited
# type seems unhelpful, so let's just skip refining the parent expression.
if new_parent_types:
output_map[parent_expr] = make_simplified_union(new_parent_types)

return output_map
3 changes: 3 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2704,6 +2704,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
index = e.index
left_type = get_proper_type(left_type)

# Visit the index, just to make sure we have a type for it available
self.accept(index)

if isinstance(left_type, UnionType):
original_type = original_type or left_type
return make_simplified_union([self.visit_index_with_type(typ, e,
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'check-isinstance.test',
'check-lists.test',
'check-namedtuple.test',
'check-narrowing.test',
'check-typeddict.test',
'check-type-aliases.test',
'check-ignore.test',
Expand Down
107 changes: 98 additions & 9 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set
from typing import cast, Optional, List, Sequence, Set, Union, TypeVar, Type as TypingType
import sys

from mypy.types import (
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType,
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
copy_type, TypeAliasType
)
Expand Down Expand Up @@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance:
return Instance(info, [join_type_list(typ.items)])


def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
"""Returns the Instance fallback for this type if one exists.
Otherwise, returns None.
"""
if isinstance(typ, Instance):
return typ
elif isinstance(typ, TupleType):
return tuple_fallback(typ)
elif isinstance(typ, TypedDictType):
return typ.fallback
elif isinstance(typ, FunctionLike):
return typ.fallback
elif isinstance(typ, LiteralType):
return typ.fallback
else:
return None


def type_object_type_from_function(signature: FunctionLike,
info: TypeInfo,
def_info: TypeInfo,
Expand Down Expand Up @@ -475,27 +494,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
2. 'typ' is a LiteralType containing a string
3. 'typ' is a UnionType containing only LiteralType of strings
"""
typ = get_proper_type(typ)

if isinstance(expr, StrExpr):
return [expr.value]

# TODO: See if we can eliminate this function and call the below one directly
return try_getting_str_literals_from_type(typ)


def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
"""If the given expression or type corresponds to a string Literal
or a union of string Literals, returns a list of the underlying strings.
Otherwise, returns None.
For example, if we had the type 'Literal["foo", "bar"]' as input, this function
would return a list of strings ["foo", "bar"].
"""
return try_getting_literals_from_type(typ, str, "builtins.str")


def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
"""If the given expression or type corresponds to an int Literal
or a union of int Literals, returns a list of the underlying ints.
Otherwise, returns None.
For example, if we had the type 'Literal[1, 2, 3]' as input, this function
would return a list of ints [1, 2, 3].
"""
return try_getting_literals_from_type(typ, int, "builtins.int")


T = TypeVar('T')


def try_getting_literals_from_type(typ: Type,
target_literal_type: TypingType[T],
target_fullname: str) -> Optional[List[T]]:
"""If the given expression or type corresponds to a Literal or
union of Literals where the underlying values corresponds to the given
target type, returns a list of those underlying values. Otherwise,
returns None.
"""
typ = get_proper_type(typ)

if isinstance(typ, Instance) and typ.last_known_value is not None:
possible_literals = [typ.last_known_value] # type: List[Type]
elif isinstance(typ, UnionType):
possible_literals = list(typ.items)
else:
possible_literals = [typ]

strings = []
literals = [] # type: List[T]
for lit in get_proper_types(possible_literals):
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname:
val = lit.value
assert isinstance(val, str)
strings.append(val)
if isinstance(val, target_literal_type):
literals.append(val)
else:
return None
else:
return None
return strings
return literals


def get_enum_values(typ: Instance) -> List[str]:
Expand Down Expand Up @@ -587,3 +645,34 @@ def coerce_to_literal(typ: Type) -> ProperType:
if len(enum_values) == 1:
return LiteralType(value=enum_values[0], fallback=typ)
return typ


def lookup_attribute_type(typ: Type, key: Union[str, int]) -> Optional[Type]:
typ = get_proper_type(typ)
if isinstance(key, int):
# Int keys apply to tuples and namedtuples
if isinstance(typ, TupleType):
try:
return typ.items[key]
except IndexError:
return None
else:
# Str keys apply to typed dicts, named tuples, instances, and anything that has
# an instance fallback
if isinstance(typ, TypedDictType):
return typ.items.get(key)

instance = try_getting_instance_fallback(typ)
if instance is None:
return None

symbol = instance.type.get(key)
if symbol is None:
return None

if symbol.type is None:
return None

return expand_type_by_instance(symbol.type, instance)

return None
Loading

0 comments on commit 0526133

Please sign in to comment.