diff --git a/mypy/checker.py b/mypy/checker.py index 557ceb8a71c09..f610ebd50bed8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6,7 +6,7 @@ from typing import ( Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, - Sequence + Mapping, Sequence ) from typing_extensions import Final @@ -47,7 +47,9 @@ ) from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, - erase_def_to_union_or_bound, erase_to_union_or_bound, + erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, + try_getting_str_literals_from_type, try_getting_int_literals_from_type, + tuple_fallback, lookup_attribute_type, is_singleton_type, try_expanding_enum_to_union, true_only, false_only, function_type, ) from mypy import message_registry @@ -72,9 +74,6 @@ from mypy.plugin import Plugin, CheckerPluginInterface from mypy.sharedparse import BINARY_MAGIC_METHODS from mypy.scope import Scope -from mypy.typeops import ( - tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union -) from mypy import state, errorcodes as codes from mypy.traverser import has_return_statement, all_return_statements from mypy.errorcodes import ErrorCode @@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression Guaranteed to not return None, None. (But may return {}, {}) """ + if_map, else_map = self.find_isinstance_check_helper(node) + new_if_map = propagate_up_typemap_info(self.type_map, if_map) + new_else_map = propagate_up_typemap_info(self.type_map, else_map) + return new_if_map, new_else_map + + def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]: type_map = self.type_map if is_true_literal(node): return {}, None @@ -3835,23 +3840,23 @@ def find_isinstance_check(self, node: Expression else None) return if_map, else_map elif isinstance(node, OpExpr) and node.op == 'and': - left_if_vars, left_else_vars = self.find_isinstance_check(node.left) - right_if_vars, right_else_vars = self.find_isinstance_check(node.right) + left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) + right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) # (e1 and e2) is true if both e1 and e2 are true, # and false if at least one of e1 and e2 is false. return (and_conditional_maps(left_if_vars, right_if_vars), or_conditional_maps(left_else_vars, right_else_vars)) elif isinstance(node, OpExpr) and node.op == 'or': - left_if_vars, left_else_vars = self.find_isinstance_check(node.left) - right_if_vars, right_else_vars = self.find_isinstance_check(node.right) + left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left) + right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right) # (e1 or e2) is true if at least one of e1 or e2 is true, # and false if both e1 and e2 are false. return (or_conditional_maps(left_if_vars, right_if_vars), and_conditional_maps(left_else_vars, right_else_vars)) elif isinstance(node, UnaryExpr) and node.op == 'not': - left, right = self.find_isinstance_check(node.expr) + left, right = self.find_isinstance_check_helper(node.expr) return right, left # Not a supported isinstance check @@ -4780,3 +4785,96 @@ def has_bool_item(typ: ProperType) -> bool: return any(is_named_instance(item, 'builtins.bool') for item in typ.items) return False + + +def propagate_up_typemap_info(existing_types: Mapping[Expression, Type], + new_types: TypeMap) -> TypeMap: + """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. + + Specifically, this function accepts two mappings of expression to original types: + the original mapping (existing_types), and a new mapping (new_types) intended to + update the original. + + This function iterates through new_types and attempts to use the information to try + refining the parent type if + """ + if new_types is None: + return None + output_map = {} + for expr, typ in new_types.items(): + # The original inferred type should always be present in the output map, of course + output_map[expr] = typ + + # Next, check and see if this expression is one that's attempting to + # "index" into the parent type. If so, grab both the parent and the "key". + keys = [] # type: Sequence[Union[str, int]] + if isinstance(expr, MemberExpr): + parent_expr = expr.expr + parent_type = existing_types.get(parent_expr) + variant_name = expr.name + keys = [variant_name] + elif isinstance(expr, IndexExpr): + parent_expr = expr.base + parent_type = existing_types.get(parent_expr) + + variant_type = existing_types.get(expr.index) + if variant_type is None: + continue + + str_literals = try_getting_str_literals_from_type(variant_type) + if str_literals is not None: + keys = str_literals + else: + int_literals = try_getting_int_literals_from_type(variant_type) + if int_literals is not None: + keys = int_literals + else: + continue + else: + continue + + # We don't try inferring anything if we've either already inferred something for + # the parent expression or if the parent somehow doesn't already have an existing type + if parent_expr in new_types or parent_type is None: + continue + + # If the parent isn't a union, we won't be able to perform any useful refinements. + # So, give up and carry on. + # + # TODO: We currently refine just the immediate parent. Should we also try refining + # any parents of the parents? + # + # One quick-and-dirty way of doing this would be to have the caller repeatedly run + # this function until we seem fixpoint, but that seems expensive. + parent_type = get_proper_type(parent_type) + if not isinstance(parent_type, UnionType): + continue + + # Take each potential parent type in the union and try "indexing" into it using. + # Does the resulting type overlap with the deduced type of the original expression? + # If so, keep the parent type in the union. + new_parent_types = [] + for item in parent_type.items: + item = get_proper_type(item) + member_types = [] + for key in keys: + t = lookup_attribute_type(item, key) + if t is not None: + member_types.append(t) + member_type_for_item = make_simplified_union(member_types) + if member_type_for_item is None: + # We were unable to obtain the member type. So, we give up on refining this + # parent type entirely. + new_parent_types = [] + break + + if is_overlapping_types(member_type_for_item, typ): + new_parent_types.append(item) + + # If none of the parent types overlap (if we derived an empty union), either + # we deliberately aborted or something went wrong. Deriving the uninhabited + # type seems unhelpful, so let's just skip refining the parent expression. + if new_parent_types: + output_map[parent_expr] = make_simplified_union(new_parent_types) + + return output_map diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b86801e25f1b8..a13c5d11809e7 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2704,6 +2704,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, index = e.index left_type = get_proper_type(left_type) + # Visit the index, just to make sure we have a type for it available + self.accept(index) + if isinstance(left_type, UnionType): original_type = original_type or left_type return make_simplified_union([self.visit_index_with_type(typ, e, diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 5fd5405ec4e8c..2747d1c034d1c 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -46,6 +46,7 @@ 'check-isinstance.test', 'check-lists.test', 'check-namedtuple.test', + 'check-narrowing.test', 'check-typeddict.test', 'check-type-aliases.test', 'check-ignore.test', diff --git a/mypy/typeops.py b/mypy/typeops.py index b26aa8b3ea73c..78fb3cdc76329 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,12 +5,12 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set +from typing import cast, Optional, List, Sequence, Set, Union, TypeVar, Type as TypingType import sys from mypy.types import ( TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, - TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, + TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, copy_type, TypeAliasType ) @@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance: return Instance(info, [join_type_list(typ.items)]) +def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]: + """Returns the Instance fallback for this type if one exists. + + Otherwise, returns None. + """ + if isinstance(typ, Instance): + return typ + elif isinstance(typ, TupleType): + return tuple_fallback(typ) + elif isinstance(typ, TypedDictType): + return typ.fallback + elif isinstance(typ, FunctionLike): + return typ.fallback + elif isinstance(typ, LiteralType): + return typ.fallback + else: + return None + + def type_object_type_from_function(signature: FunctionLike, info: TypeInfo, def_info: TypeInfo, @@ -475,11 +494,48 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]] 2. 'typ' is a LiteralType containing a string 3. 'typ' is a UnionType containing only LiteralType of strings """ - typ = get_proper_type(typ) - if isinstance(expr, StrExpr): return [expr.value] + # TODO: See if we can eliminate this function and call the below one directly + return try_getting_str_literals_from_type(typ) + + +def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]: + """If the given expression or type corresponds to a string Literal + or a union of string Literals, returns a list of the underlying strings. + Otherwise, returns None. + + For example, if we had the type 'Literal["foo", "bar"]' as input, this function + would return a list of strings ["foo", "bar"]. + """ + return try_getting_literals_from_type(typ, str, "builtins.str") + + +def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]: + """If the given expression or type corresponds to an int Literal + or a union of int Literals, returns a list of the underlying ints. + Otherwise, returns None. + + For example, if we had the type 'Literal[1, 2, 3]' as input, this function + would return a list of ints [1, 2, 3]. + """ + return try_getting_literals_from_type(typ, int, "builtins.int") + + +T = TypeVar('T') + + +def try_getting_literals_from_type(typ: Type, + target_literal_type: TypingType[T], + target_fullname: str) -> Optional[List[T]]: + """If the given expression or type corresponds to a Literal or + union of Literals where the underlying values corresponds to the given + target type, returns a list of those underlying values. Otherwise, + returns None. + """ + typ = get_proper_type(typ) + if isinstance(typ, Instance) and typ.last_known_value is not None: possible_literals = [typ.last_known_value] # type: List[Type] elif isinstance(typ, UnionType): @@ -487,15 +543,17 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]] else: possible_literals = [typ] - strings = [] + literals = [] # type: List[T] for lit in get_proper_types(possible_literals): - if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str': + if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname: val = lit.value - assert isinstance(val, str) - strings.append(val) + if isinstance(val, target_literal_type): + literals.append(val) + else: + return None else: return None - return strings + return literals def get_enum_values(typ: Instance) -> List[str]: @@ -587,3 +645,34 @@ def coerce_to_literal(typ: Type) -> ProperType: if len(enum_values) == 1: return LiteralType(value=enum_values[0], fallback=typ) return typ + + +def lookup_attribute_type(typ: Type, key: Union[str, int]) -> Optional[Type]: + typ = get_proper_type(typ) + if isinstance(key, int): + # Int keys apply to tuples and namedtuples + if isinstance(typ, TupleType): + try: + return typ.items[key] + except IndexError: + return None + else: + # Str keys apply to typed dicts, named tuples, instances, and anything that has + # an instance fallback + if isinstance(typ, TypedDictType): + return typ.items.get(key) + + instance = try_getting_instance_fallback(typ) + if instance is None: + return None + + symbol = instance.type.get(key) + if symbol is None: + return None + + if symbol.type is None: + return None + + return expand_type_by_instance(symbol.type, instance) + + return None diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test new file mode 100644 index 0000000000000..8ba764f56afbb --- /dev/null +++ b/test-data/unit/check-narrowing.test @@ -0,0 +1,262 @@ +[case testNarrowingParentWithEnumsBasic] +from enum import Enum +from dataclasses import dataclass +from typing import NamedTuple, Tuple, Union +from typing_extensions import Literal, TypedDict + +class Key(Enum): + A = 1 + B = 2 + C = 3 + +class Object1: + key: Literal[Key.A] + foo: int +class Object2: + key: Literal[Key.B] + bar: str + +@dataclass +class Dataclass1: + key: Literal[Key.A] + foo: int +@dataclass +class Dataclass2: + key: Literal[Key.B] + foo: str + +class NamedTuple1(NamedTuple): + key: Literal[Key.A] + foo: int +class NamedTuple2(NamedTuple): + key: Literal[Key.B] + foo: str + +Tuple1 = Tuple[Literal[Key.A], int] +Tuple2 = Tuple[Literal[Key.B], str] + +x1: Union[Object1, Object2] +if x1.key is Key.A: + reveal_type(x1) # N: Revealed type is '__main__.Object1' + reveal_type(x1.key) # N: Revealed type is 'Literal[__main__.Key.A]' +else: + reveal_type(x1) # N: Revealed type is '__main__.Object2' + reveal_type(x1.key) # N: Revealed type is 'Literal[__main__.Key.B]' + +x2: Union[Dataclass1, Dataclass2] +if x2.key is Key.A: + reveal_type(x2) # N: Revealed type is '__main__.Dataclass1' + reveal_type(x2.key) # N: Revealed type is 'Literal[__main__.Key.A]' +else: + reveal_type(x2) # N: Revealed type is '__main__.Dataclass2' + reveal_type(x2.key) # N: Revealed type is 'Literal[__main__.Key.B]' + +x3: Union[NamedTuple1, NamedTuple2] +if x3.key is Key.A: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal[__main__.Key.A], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3.key) # N: Revealed type is 'Literal[__main__.Key.A]' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal[__main__.Key.B], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3.key) # N: Revealed type is 'Literal[__main__.Key.B]' +if x3[0] is Key.A: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal[__main__.Key.A], builtins.int, fallback=__main__.NamedTuple1]' + reveal_type(x3[0]) # N: Revealed type is 'Literal[__main__.Key.A]' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[Literal[__main__.Key.B], builtins.str, fallback=__main__.NamedTuple2]' + reveal_type(x3[0]) # N: Revealed type is 'Literal[__main__.Key.B]' + +x4: Union[Tuple1, Tuple2] +if x4[0] is Key.A: + reveal_type(x4) # N: Revealed type is 'Tuple[Literal[__main__.Key.A], builtins.int]' + reveal_type(x4[0]) # N: Revealed type is 'Literal[__main__.Key.A]' +else: + reveal_type(x4) # N: Revealed type is 'Tuple[Literal[__main__.Key.B], builtins.str]' + reveal_type(x4[0]) # N: Revealed type is 'Literal[__main__.Key.B]' + +[case testNarrowingParentWithIsInstanceBasic] +from dataclasses import dataclass +from typing import NamedTuple, Tuple, Union +from typing_extensions import TypedDict + +class Object1: + key: int +class Object2: + key: str + +@dataclass +class Dataclass1: + key: int +@dataclass +class Dataclass2: + key: str + +class NamedTuple1(NamedTuple): + key: int +class NamedTuple2(NamedTuple): + key: str + +Tuple1 = Tuple[int] +Tuple2 = Tuple[str] + +x1: Union[Object1, Object2] +if isinstance(x1.key, int): + reveal_type(x1) # N: Revealed type is '__main__.Object1' +else: + reveal_type(x1) # N: Revealed type is '__main__.Object2' + +x2: Union[Dataclass1, Dataclass2] +if isinstance(x2.key, int): + reveal_type(x2) # N: Revealed type is '__main__.Dataclass1' +else: + reveal_type(x2) # N: Revealed type is '__main__.Dataclass2' + +x3: Union[NamedTuple1, NamedTuple2] +if isinstance(x3.key, int): + reveal_type(x3) # N: Revealed type is 'Tuple[builtins.int, fallback=__main__.NamedTuple1]' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[builtins.str, fallback=__main__.NamedTuple2]' +if isinstance(x3[0], int): + reveal_type(x3) # N: Revealed type is 'Tuple[builtins.int, fallback=__main__.NamedTuple1]' +else: + reveal_type(x3) # N: Revealed type is 'Tuple[builtins.str, fallback=__main__.NamedTuple2]' + +x4: Union[Tuple1, Tuple2] +if isinstance(x4[0], int): + reveal_type(x4) # N: Revealed type is 'Tuple[builtins.int]' +else: + reveal_type(x4) # N: Revealed type is 'Tuple[builtins.str]' +[builtins fixtures/isinstance.pyi] + +[case testNarrowingParentMultipleKeys] +# flags: --warn-unreachable +from enum import Enum +from typing import Union +from typing_extensions import Literal + +class Key(Enum): + A = 1 + B = 2 + C = 3 + D = 4 + +class Object1: + key: Literal[Key.A, Key.C] +class Object2: + key: Literal[Key.B, Key.C] + +x: Union[Object1, Object2] +if x.key is Key.A: + reveal_type(x) # N: Revealed type is '__main__.Object1' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' + +if x.key is Key.C: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' + +if x.key is Key.D: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2]' + +[case testNarrowingParentWithMultipleParents] +from enum import Enum +from typing import Union +from typing_extensions import Literal + +class Key(Enum): + A = 1 + B = 2 + C = 3 + +class Object1: + key: Literal[Key.A] +class Object2: + key: Literal[Key.B] +class Object3: + key: Literal[Key.C] +class Object4: + key: str + +x: Union[Object1, Object2, Object3, Object4] +if x.key is Key.A: + reveal_type(x) # N: Revealed type is '__main__.Object1' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object2, __main__.Object3, __main__.Object4]' + +if isinstance(x.key, str): + reveal_type(x) # N: Revealed type is '__main__.Object4' +else: + reveal_type(x) # N: Revealed type is 'Union[__main__.Object1, __main__.Object2, __main__.Object3]' +[builtins fixtures/isinstance.pyi] + +[case testNarrowingParentsWithGenerics] +from typing import Union, TypeVar, Generic + +T = TypeVar('T') +class Wrapper(Generic[T]): + key: T + +x: Union[Wrapper[int], Wrapper[str]] +if isinstance(x.key, int): + reveal_type(x) # N: Revealed type is '__main__.Wrapper[builtins.int]' +else: + reveal_type(x) # N: Revealed type is '__main__.Wrapper[builtins.str]' +[builtins fixtures/isinstance.pyi] + +[case testNarrowingParentWithParentMixtures] +from enum import Enum +from typing import Union, NamedTuple +from typing_extensions import Literal, TypedDict + +class Key(Enum): + A = 1 + B = 2 + C = 3 + +class KeyedObject: + key: Literal[Key.A] +class KeyedTypedDict(TypedDict): + key: Literal[Key.B] +class KeyedNamedTuple(NamedTuple): + key: Literal[Key.C] + +ok_mixture: Union[KeyedObject, KeyedNamedTuple] +if ok_mixture.key is Key.A: + reveal_type(ok_mixture) # N: Revealed type is '__main__.KeyedObject' +else: + reveal_type(ok_mixture) # N: Revealed type is 'Tuple[Literal[__main__.Key.C], fallback=__main__.KeyedNamedTuple]' + +# Each reveal_type below really ought to be a union -- e.g. we ought short-circuit +# and skip inferring anything about the parents. Currently, we overreach in some +# cases and infer something a bit non-sensical due to how we're normalizing +# "lookup" operations. +# +# This is a bit confusing from a usability standpoint, but is probably fine: +# we don't guarantee sensible results after errors anyways. (And making sure +# these nonsensical lookups result in an error is the main purpose of this +# test case). + +impossible_mixture: Union[KeyedObject, KeyedTypedDict] +if impossible_mixture.key is Key.A: # E: Item "KeyedTypedDict" of "Union[KeyedObject, KeyedTypedDict]" has no attribute "key" + reveal_type(impossible_mixture) # N: Revealed type is '__main__.KeyedObject' +else: + reveal_type(impossible_mixture) # N: Revealed type is 'Union[__main__.KeyedObject, TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]})]' + +if impossible_mixture["key"] is Key.A: # E: Value of type "Union[KeyedObject, KeyedTypedDict]" is not indexable + reveal_type(impossible_mixture) # N: Revealed type is '__main__.KeyedObject' +else: + reveal_type(impossible_mixture) # N: Revealed type is 'Union[__main__.KeyedObject, TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]})]' + +weird_mixture: Union[KeyedTypedDict, KeyedNamedTuple] +if weird_mixture["key"] is Key.B: # E: Invalid tuple index type (actual type "str", expected type "Union[int, slice]") + reveal_type(weird_mixture) # N: Revealed type is 'TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]})' +else: + reveal_type(weird_mixture) # N: Revealed type is 'Union[TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]}), Tuple[Literal[__main__.Key.C], fallback=__main__.KeyedNamedTuple]]' + +if weird_mixture[0] is Key.B: # E: TypedDict key must be a string literal; expected one of ('key') + reveal_type(weird_mixture) # N: Revealed type is 'TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]})' +else: + reveal_type(weird_mixture) # N: Revealed type is 'Union[TypedDict('__main__.KeyedTypedDict', {'key': Literal[__main__.Key.B]}), Tuple[Literal[__main__.Key.C], fallback=__main__.KeyedNamedTuple]]' +[builtins fixtures/slice.pyi] diff --git a/test-data/unit/fixtures/tuple.pyi b/test-data/unit/fixtures/tuple.pyi index 6e000a7699fdc..686e2dd55818d 100644 --- a/test-data/unit/fixtures/tuple.pyi +++ b/test-data/unit/fixtures/tuple.pyi @@ -21,7 +21,8 @@ class function: pass class ellipsis: pass # We need int and slice for indexing tuples. -class int: pass +class int: + def __neg__(self) -> 'int': pass class slice: pass class bool: pass class str: pass # For convenience