Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More helpful type guards #14238

Merged
merged 9 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
PR review
  • Loading branch information
A5rocks committed Dec 5, 2022
commit 3ea4baf3c1ee4de5c4fecc872e36ab35e585cf89
11 changes: 9 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5288,8 +5288,15 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
if node.arg_kinds[0] != nodes.ARG_POS:
# the first argument might be used as a kwarg
called_type = get_proper_type(self.lookup_type(node.callee))
assert isinstance(called_type, CallableType)
name = called_type.arg_names[0]
assert isinstance(called_type, (CallableType, Overloaded))

# *assuming* the overloaded function is correct, there's a couple cases:
# 1) The first argument has different names, but is pos-only. We don't
# care about this case, the argument must be passed positionally.
# 2) The first argument allows keyword reference, therefore must be the
# same between overloads.
name = called_type.items[0].arg_names[0]

if name in node.arg_names:
idx = node.arg_names.index(name)
# we want the idx-th variable to be narrowed
Expand Down
75 changes: 64 additions & 11 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,50 @@ class Z:

def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithKeywordArg]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, x: object) -> TypeGuard[int]:
...

def typeguard(x: object) -> TypeGuard[int]:
...

n: object
if typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"

if Z().typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testTypeGuardWithPositionalOnlyArg]
from typing_extensions import TypeGuard

def typeguard(x: object, /) -> TypeGuard[int]:
...

n: object
if typeguard(n):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testTypeGuardKeywordFollowingWalrus]
from typing import cast
from typing_extensions import TypeGuard

def typeguard(x: object) -> TypeGuard[int]:
...

if typeguard(x=(n := cast(object, "hi"))):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

# make sure not to break other things
[case testStaticMethodTypeGuard]
from typing_extensions import TypeGuard

class Y:
@staticmethod
Expand All @@ -623,20 +665,31 @@ if Y.typeguard(x):
[builtins fixtures/tuple.pyi]
[builtins fixtures/classmethod.pyi]

[case testTypeGuardWithKeywordArg]
[case testTypeGuardKwargFollowingThroughOverloaded]
from typing import overload, Union
from typing_extensions import TypeGuard

class Z:
def typeguard(self, x: object) -> TypeGuard[int]:
...
@overload
def typeguard(x: object, y: str) -> TypeGuard[str]:
...

def typeguard(x: object) -> TypeGuard[int]:
@overload
def typeguard(x: object, y: int) -> TypeGuard[int]:
...

n: object
if typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
def typeguard(x: object, y: Union[int, str]) -> Union[TypeGuard[int], TypeGuard[str]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this return type makes sense. But can add an error in another PR

...

if Z().typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
x: object
if typeguard(x=x, y=42):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(y=42, x=x):
reveal_type(x) # N: Revealed type is "builtins.int"

if typeguard(x=x, y="42"):
reveal_type(x) # N: Revealed type is "builtins.str"

if typeguard(y="42", x=x):
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]