diff --git a/mypy/checker.py b/mypy/checker.py index cd05643271211..284c536220bb5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -226,8 +226,10 @@ def accept(self, node: Node, type_context: Type = None) -> Type: else: return typ - def accept_loop(self, body: Node, else_body: Node = None) -> Type: + def accept_loop(self, body: Node, else_body: Node = None, *, + exit_condition: Node = None) -> Type: """Repeatedly type check a loop body until the frame doesn't change. + If exit_condition is set, assume it must be False on exit from the loop. Then check the else_body. """ @@ -240,6 +242,13 @@ def accept_loop(self, body: Node, else_body: Node = None) -> Type: if not self.binder.last_pop_changed: break self.binder.pop_loop_frame() + if exit_condition: + _, else_map = find_isinstance_check( + exit_condition, self.type_map, self.typing_mode_weak() + ) + if else_map: + for var, type in else_map.items(): + self.binder.push(var, type) if else_body: self.accept(else_body) @@ -1465,7 +1474,8 @@ def visit_if_stmt(self, s: IfStmt) -> Type: def visit_while_stmt(self, s: WhileStmt) -> Type: """Type check a while statement.""" - self.accept_loop(IfStmt([s.expr], [s.body], None), s.else_body) + self.accept_loop(IfStmt([s.expr], [s.body], None), s.else_body, + exit_condition=s.expr) def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> Type: diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index f8ffd61efaadf..d3632907e26a2 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -908,6 +908,42 @@ def bar() -> None: [out] main: note: In function "bar": +[case testWhileExitCondition1] +from typing import Union +x = 1 # type: Union[int, str] +while isinstance(x, int): + if bool(): + continue + x = 'a' +else: + reveal_type(x) # E: Revealed type is 'builtins.str' +reveal_type(x) # E: Revealed type is 'builtins.str' +[builtins fixtures/isinstance.py] + +[case testWhileExitCondition2] +from typing import Union +x = 1 # type: Union[int, str] +while isinstance(x, int): + if bool(): + break + x = 'a' +else: + reveal_type(x) # E: Revealed type is 'builtins.str' +reveal_type(x) # E: Revealed type is 'Union[builtins.int, builtins.str]' +[builtins fixtures/isinstance.py] + +[case testWhileLinkedList] +from typing import Union +LinkedList = Union['Cons', 'Nil'] +class Nil: pass +class Cons: + tail = None # type: LinkedList +def last(x: LinkedList) -> Nil: + while isinstance(x, Cons): + x = x.tail + return x +[builtins fixtures/isinstance.py] + [case testReturnAndFlow] def foo() -> int: return 1 and 2