Skip to content

Commit

Permalink
More helpful type guards (#14238)
Browse files Browse the repository at this point in the history
Fixes #13199
Refs #14425
  • Loading branch information
A5rocks committed Jan 31, 2023
1 parent 7c14eba commit 28c67cb
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 10 deletions.
22 changes: 19 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5350,10 +5350,26 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
# the first argument might be used as a kwarg
called_type = get_proper_type(self.lookup_type(node.callee))
assert isinstance(called_type, (CallableType, Overloaded))

# *assuming* the overloaded function is correct, there's a couple cases:
# 1) The first argument has different names, but is pos-only. We don't
# care about this case, the argument must be passed positionally.
# 2) The first argument allows keyword reference, therefore must be the
# same between overloads.
name = called_type.items[0].arg_names[0]

if name in node.arg_names:
idx = node.arg_names.index(name)
# we want the idx-th variable to be narrowed
expr = collapse_walrus(node.args[idx])
else:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
# Note: we wrap the target type, so that we can special case later.
# Namely, for isinstance() we use a normal meet, while TypeGuard is
Expand Down
14 changes: 14 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,20 @@ def analyze_func_def(self, defn: FuncDef) -> None:
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
# type guards need to have a positional argument, to spec
if (
result.type_guard
and ARG_POS not in result.arg_kinds[self.is_class_scope() :]
and not defn.is_static
):
self.fail(
"TypeGuard functions must have a positional argument",
result,
code=codes.VALID_TYPE,
)
# in this case, we just kind of just ... remove the type guard.
result = result.copy_modified(type_guard=None)

result = self.remove_unpack_kwargs(defn, result)
if has_self_type and self.type is not None:
info = self.type
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,34 @@ class C(Generic[T]):
main:10: note: Revealed type is "builtins.int"
main:10: note: Revealed type is "builtins.str"

[case testTypeGuardWithPositionalOnlyArg]
# flags: --python-version 3.8
from typing_extensions import TypeGuard

def typeguard(x: object, /) -> TypeGuard[int]:
...

n: object
if typeguard(n):
reveal_type(n)
[builtins fixtures/tuple.pyi]
[out]
main:9: note: Revealed type is "builtins.int"

[case testTypeGuardKeywordFollowingWalrus]
# flags: --python-version 3.8
from typing import cast
from typing_extensions import TypeGuard

def typeguard(x: object) -> TypeGuard[int]:
...

if typeguard(x=(n := cast(object, "hi"))):
reveal_type(n)
[builtins fixtures/tuple.pyi]
[out]
main:9: note: Revealed type is "builtins.int"

[case testNoCrashOnAssignmentExprClass]
class C:
[(j := i) for i in [1, 2, 3]] # E: Assignment expression within a comprehension cannot be used in a class body
Expand Down
88 changes: 81 additions & 7 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ reveal_type(foo) # N: Revealed type is "def (a: builtins.object) -> TypeGuard[b
[case testTypeGuardCallArgsNone]
from typing_extensions import TypeGuard
class Point: pass
# TODO: error on the 'def' line (insufficient args for type guard)
def is_point() -> TypeGuard[Point]: pass

def is_point() -> TypeGuard[Point]: pass # E: TypeGuard functions must have a positional argument
def main(a: object) -> None:
if is_point():
reveal_type(a) # N: Revealed type is "builtins.object"
Expand Down Expand Up @@ -227,13 +227,13 @@ def main(a: object) -> None:
from typing_extensions import TypeGuard
def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass
def main1(a: object) -> None:
# This is debatable -- should we support these cases?
if is_float(a=a, b=1):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(a=a, b=1): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
if is_float(b=1, a=a):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(b=1, a=a): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
# This is debatable -- should we support these cases?

ta = (a,)
if is_float(*ta): # E: Type guard requires positional argument
Expand Down Expand Up @@ -597,3 +597,77 @@ def func(names: Tuple[str, ...]):
if is_two_element_tuple(names):
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeGuardErroneousDefinitionFails]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, *, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...

def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithKeywordArg]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, x: object) -> TypeGuard[int]:
...

def typeguard(x: object) -> TypeGuard[int]:
...

n: object
if typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"

if Z().typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testStaticMethodTypeGuard]
from typing_extensions import TypeGuard

class Y:
@staticmethod
def typeguard(h: object) -> TypeGuard[int]:
...

x: object
if Y().typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
if Y.typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
[builtins fixtures/classmethod.pyi]

[case testTypeGuardKwargFollowingThroughOverloaded]
from typing import overload, Union
from typing_extensions import TypeGuard

@overload
def typeguard(x: object, y: str) -> TypeGuard[str]:
...

@overload
def typeguard(x: object, y: int) -> TypeGuard[int]:
...

def typeguard(x: object, y: Union[int, str]) -> Union[TypeGuard[int], TypeGuard[str]]:
...

x: object
if typeguard(x=x, y=42):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(y=42, x=x):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(x=x, y="42"):
reveal_type(x) # N: Revealed type is "builtins.str"

if typeguard(y="42", x=x):
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]

0 comments on commit 28c67cb

Please sign in to comment.