Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further improvements to functools.partial handling #17425

Merged
merged 7 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,11 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
partial_names.append(fn_type.arg_names[i])
elif actuals:
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
# Don't add params for arguments passed positionally
continue
# Add defaulted params for arguments passed via keyword
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
partial_types.append(arg_type)
Expand Down Expand Up @@ -268,15 +270,25 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]
# See comments for similar actual to formal code above
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
seen_args = set()
for i, param in enumerate(ctx.args):
for j, a in enumerate(param):
if a in seen_args:
continue
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this it seems to me a better strategy for the call site may be using get_attribute_hook() for __call__? Unfortunately this hook is not called in is_subtype() etc yet. But at least it will be possible to precisely type-check something like this in future

def foo(fn: Callable[[int, str], int]) -> None: ...
fn = partial(some_other_fn, 1, 2)
foo(fn)


result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result[0]
1 change: 1 addition & 0 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def visit_instance(self, t: Instance) -> Type:
line=t.line,
column=t.column,
last_known_value=last_known_value,
extra_attrs=t.extra_attrs,
)

def visit_type_var(self, t: TypeVarType) -> Type:
Expand Down
3 changes: 1 addition & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,8 +1417,7 @@ def __init__(
self._hash = -1

# Additional attributes defined per instance of this type. For example modules
# have different attributes per instance of types.ModuleType. This is intended
# to be "short-lived", we don't serialize it, and even don't store as variable type.
# have different attributes per instance of types.ModuleType.
self.extra_attrs = extra_attrs

def accept(self, visitor: TypeVisitor[T]) -> T:
Expand Down
88 changes: 63 additions & 25 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -408,33 +408,71 @@ def foo(cls3: Type[B[T]]):
from typing_extensions import TypedDict, Unpack
from functools import partial

class Data(TypedDict, total=False):
x: int

def f(**kwargs: Unpack[Data]) -> None: ...
def g(**kwargs: Unpack[Data]) -> None:
partial(f, **kwargs)()

class MoreData(TypedDict, total=False):
x: int
y: int

def f_more(**kwargs: Unpack[MoreData]) -> None: ...
def g_more(**kwargs: Unpack[MoreData]) -> None:
partial(f_more, **kwargs)()

class Good(TypedDict, total=False):
y: int
class Bad(TypedDict, total=False):
y: str

def h(**kwargs: Unpack[Data]) -> None:
bad: Bad
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
good: Good
partial(f_more, **kwargs)(**good)
hauntsaninja marked this conversation as resolved.
Show resolved Hide resolved
class D1(TypedDict, total=False):
a1: int

def fn1(a1: int) -> None: ... # N: "fn1" defined here
def main1(**kwargs: Unpack[D1]) -> None:
partial(fn1, **kwargs)()
partial(fn1, **kwargs)(**kwargs)
partial(fn1, **kwargs)(a1=1)
partial(fn1, **kwargs)(a1="asdf") # E: Argument "a1" to "fn1" has incompatible type "str"; expected "int"
partial(fn1, **kwargs)(oops=1) # E: Unexpected keyword argument "oops" for "fn1"

def fn2(**kwargs: Unpack[D1]) -> None: ... # N: "fn2" defined here
def main2(**kwargs: Unpack[D1]) -> None:
partial(fn2, **kwargs)()
partial(fn2, **kwargs)(**kwargs)
partial(fn2, **kwargs)(a1=1)
partial(fn2, **kwargs)(a1="asdf") # E: Argument "a1" to "fn2" has incompatible type "str"; expected "int"
partial(fn2, **kwargs)(oops=1) # E: Unexpected keyword argument "oops" for "fn2"

class D2(TypedDict, total=False):
a1: int
a2: str

class A2Good(TypedDict, total=False):
a2: str
class A2Bad(TypedDict, total=False):
a2: int

def fn3(a1: int, a2: str) -> None: ... # N: "fn3" defined here
def main3(**kwargs: Unpack[D2]) -> None:
partial(fn3, **kwargs)()
partial(fn3, **kwargs)(a1=1, a2="asdf")

partial(fn3, **kwargs)(**kwargs)

partial(fn3, **kwargs)(a1="asdf") # E: Argument "a1" to "fn3" has incompatible type "str"; expected "int"
partial(fn3, **kwargs)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn3"

a2good: A2Good
partial(fn3, **kwargs)(**a2good)
a2bad: A2Bad
partial(fn3, **kwargs)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"

def fn4(**kwargs: Unpack[D2]) -> None: ... # N: "fn4" defined here
def main4(**kwargs: Unpack[D2]) -> None:
partial(fn4, **kwargs)()
partial(fn4, **kwargs)(a1=1, a2="asdf")

partial(fn4, **kwargs)(**kwargs)

partial(fn4, **kwargs)(a1="asdf") # E: Argument "a1" to "fn4" has incompatible type "str"; expected "int"
partial(fn4, **kwargs)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn4"

a2good: A2Good
partial(fn3, **kwargs)(**a2good)
a2bad: A2Bad
partial(fn3, **kwargs)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"


def main5(**kwargs: Unpack[D2]) -> None:
partial(fn1, **kwargs)() # E: Extra argument "a2" from **args for "fn1"
partial(fn2, **kwargs)() # E: Extra argument "a2" from **args for "fn2"
[builtins fixtures/dict.pyi]


[case testFunctoolsPartialNestedGeneric]
from functools import partial
from typing import Generic, TypeVar, List
Expand Down
48 changes: 48 additions & 0 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -10497,3 +10497,51 @@ from pkg.sub import modb

[out]
==

[case testFineGrainedFunctoolsPartial]
import m

[file m.py]
from typing import Callable
from partial import p1

reveal_type(p1)
p1("a")
p1("a", 3)
p1("a", c=3)
p1(1, 3)
p1(1, "a", 3)
p1(a=1, b="a", c=3)
[builtins fixtures/dict.pyi]

[file partial.py]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ...
p1 = foo

[file partial.py.2]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ...
p1 = functools.partial(foo, 1)

[out]
m.py:4: note: Revealed type is "def (a: builtins.int, b: builtins.str, c: builtins.int =) -> builtins.int"
m.py:5: error: Too few arguments
m.py:5: error: Argument 1 has incompatible type "str"; expected "int"
m.py:6: error: Argument 1 has incompatible type "str"; expected "int"
m.py:6: error: Argument 2 has incompatible type "int"; expected "str"
m.py:7: error: Too few arguments
m.py:7: error: Argument 1 has incompatible type "str"; expected "int"
m.py:8: error: Argument 2 has incompatible type "int"; expected "str"
==
m.py:4: note: Revealed type is "functools.partial[builtins.int]"
m.py:8: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
m.py:9: error: Too many arguments for "foo"
m.py:9: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
m.py:9: error: Argument 2 to "foo" has incompatible type "str"; expected "int"
m.py:10: error: Unexpected keyword argument "a" for "foo"
partial.py:4: note: "foo" defined here
Loading