Skip to content

Commit

Permalink
Further improvements to functools.partial handling (#17425)
Browse files Browse the repository at this point in the history
- Fixes another crash case / type inference in that case
- Fix a false positive when calling the partially applied function with
kwargs
- TypeTraverse / comment / daemon test follow up ilevkivskyi mentioned
on the original PR

See also #17423
  • Loading branch information
hauntsaninja authored Jul 1, 2024
1 parent 4ae632b commit d1d3c78
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 35 deletions.
31 changes: 22 additions & 9 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,14 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
elif actuals:
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
else:
assert actuals
if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
# Don't add params for arguments passed positionally
continue
# Add defaulted params for arguments passed via keyword
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
partial_types.append(arg_type)
Expand Down Expand Up @@ -286,15 +289,25 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]
# See comments for similar actual to formal code above
actual_args = []
actual_arg_kinds = []
actual_arg_names = []
seen_args = set()
for i, param in enumerate(ctx.args):
for j, a in enumerate(param):
if a in seen_args:
continue
seen_args.add(a)
actual_args.append(a)
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])

result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result[0]
1 change: 1 addition & 0 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def visit_instance(self, t: Instance) -> Type:
line=t.line,
column=t.column,
last_known_value=last_known_value,
extra_attrs=t.extra_attrs,
)

def visit_type_var(self, t: TypeVarType) -> Type:
Expand Down
3 changes: 1 addition & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,8 +1417,7 @@ def __init__(
self._hash = -1

# Additional attributes defined per instance of this type. For example modules
# have different attributes per instance of types.ModuleType. This is intended
# to be "short-lived", we don't serialize it, and even don't store as variable type.
# have different attributes per instance of types.ModuleType.
self.extra_attrs = extra_attrs

def accept(self, visitor: TypeVisitor[T]) -> T:
Expand Down
121 changes: 97 additions & 24 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ functools.partial(1) # E: "int" not callable \

[case testFunctoolsPartialStar]
import functools
from typing import List

def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...

Expand All @@ -215,6 +216,13 @@ def bar(*a: bytes, **k: int):
p1("a", **k) # E: Argument 2 to "foo" has incompatible type "**Dict[str, int]"; expected "str"
p1(**k) # E: Argument 1 to "foo" has incompatible type "**Dict[str, int]"; expected "str"
p1(*a) # E: List or tuple expected as variadic arguments


def baz(a: int, b: int) -> int: ...
def test_baz(xs: List[int]):
p3 = functools.partial(baz, *xs)
p3()
p3(1) # E: Too many arguments for "baz"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialGeneric]
Expand Down Expand Up @@ -408,33 +416,83 @@ def foo(cls3: Type[B[T]]):
from typing_extensions import TypedDict, Unpack
from functools import partial

class Data(TypedDict, total=False):
x: int

def f(**kwargs: Unpack[Data]) -> None: ...
def g(**kwargs: Unpack[Data]) -> None:
partial(f, **kwargs)()

class MoreData(TypedDict, total=False):
x: int
y: int
class D1(TypedDict, total=False):
a1: int

def fn1(a1: int) -> None: ... # N: "fn1" defined here
def main1(**d1: Unpack[D1]) -> None:
partial(fn1, **d1)()
partial(fn1, **d1)(**d1)
partial(fn1, **d1)(a1=1)
partial(fn1, **d1)(a1="asdf") # E: Argument "a1" to "fn1" has incompatible type "str"; expected "int"
partial(fn1, **d1)(oops=1) # E: Unexpected keyword argument "oops" for "fn1"

def fn2(**kwargs: Unpack[D1]) -> None: ... # N: "fn2" defined here
def main2(**d1: Unpack[D1]) -> None:
partial(fn2, **d1)()
partial(fn2, **d1)(**d1)
partial(fn2, **d1)(a1=1)
partial(fn2, **d1)(a1="asdf") # E: Argument "a1" to "fn2" has incompatible type "str"; expected "int"
partial(fn2, **d1)(oops=1) # E: Unexpected keyword argument "oops" for "fn2"

class D2(TypedDict, total=False):
a1: int
a2: str

class A2Good(TypedDict, total=False):
a2: str
class A2Bad(TypedDict, total=False):
a2: int

def fn3(a1: int, a2: str) -> None: ... # N: "fn3" defined here
def main3(a2good: A2Good, a2bad: A2Bad, **d2: Unpack[D2]) -> None:
partial(fn3, **d2)()
partial(fn3, **d2)(a1=1, a2="asdf")

partial(fn3, **d2)(**d2)

partial(fn3, **d2)(a1="asdf") # E: Argument "a1" to "fn3" has incompatible type "str"; expected "int"
partial(fn3, **d2)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn3"

partial(fn3, **d2)(**a2good)
partial(fn3, **d2)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"

def fn4(**kwargs: Unpack[D2]) -> None: ... # N: "fn4" defined here
def main4(a2good: A2Good, a2bad: A2Bad, **d2: Unpack[D2]) -> None:
partial(fn4, **d2)()
partial(fn4, **d2)(a1=1, a2="asdf")

partial(fn4, **d2)(**d2)

partial(fn4, **d2)(a1="asdf") # E: Argument "a1" to "fn4" has incompatible type "str"; expected "int"
partial(fn4, **d2)(a1=1, a2="asdf", oops=1) # E: Unexpected keyword argument "oops" for "fn4"

partial(fn3, **d2)(**a2good)
partial(fn3, **d2)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"

def main5(**d2: Unpack[D2]) -> None:
partial(fn1, **d2)() # E: Extra argument "a2" from **args for "fn1"
partial(fn2, **d2)() # E: Extra argument "a2" from **args for "fn2"

def main6(a2good: A2Good, a2bad: A2Bad, **d1: Unpack[D1]) -> None:
partial(fn3, **d1)() # E: Missing positional argument "a1" in call to "fn3"
partial(fn3, **d1)("asdf") # E: Too many positional arguments for "fn3" \
# E: Too few arguments for "fn3" \
# E: Argument 1 to "fn3" has incompatible type "str"; expected "int"
partial(fn3, **d1)(a2="asdf")
partial(fn3, **d1)(**a2good)
partial(fn3, **d1)(**a2bad) # E: Argument "a2" to "fn3" has incompatible type "int"; expected "str"

partial(fn4, **d1)()
partial(fn4, **d1)("asdf") # E: Too many positional arguments for "fn4" \
# E: Argument 1 to "fn4" has incompatible type "str"; expected "int"
partial(fn4, **d1)(a2="asdf")
partial(fn4, **d1)(**a2good)
partial(fn4, **d1)(**a2bad) # E: Argument "a2" to "fn4" has incompatible type "int"; expected "str"

def f_more(**kwargs: Unpack[MoreData]) -> None: ...
def g_more(**kwargs: Unpack[MoreData]) -> None:
partial(f_more, **kwargs)()

class Good(TypedDict, total=False):
y: int
class Bad(TypedDict, total=False):
y: str

def h(**kwargs: Unpack[Data]) -> None:
bad: Bad
partial(f_more, **kwargs)(**bad) # E: Argument "y" to "f_more" has incompatible type "str"; expected "int"
good: Good
partial(f_more, **kwargs)(**good)
[builtins fixtures/dict.pyi]


[case testFunctoolsPartialNestedGeneric]
from functools import partial
from typing import Generic, TypeVar, List
Expand All @@ -456,6 +514,21 @@ first_kw([1]) # E: Too many positional arguments for "get" \
# E: Argument 1 to "get" has incompatible type "List[int]"; expected "int"
[builtins fixtures/list.pyi]

[case testFunctoolsPartialHigherOrder]
from functools import partial
from typing import Callable

def fn(a: int, b: str, c: bytes) -> int: ...

def callback1(fn: Callable[[str, bytes], int]) -> None: ...
def callback2(fn: Callable[[str, int], int]) -> None: ...

callback1(partial(fn, 1))
# TODO: false negative
# https://github.com/python/mypy/issues/17461
callback2(partial(fn, 1))
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialClassObjectMatchingPartial]
from functools import partial

Expand Down
48 changes: 48 additions & 0 deletions test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -10497,3 +10497,51 @@ from pkg.sub import modb

[out]
==

[case testFineGrainedFunctoolsPartial]
import m

[file m.py]
from typing import Callable
from partial import p1

reveal_type(p1)
p1("a")
p1("a", 3)
p1("a", c=3)
p1(1, 3)
p1(1, "a", 3)
p1(a=1, b="a", c=3)
[builtins fixtures/dict.pyi]

[file partial.py]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ...
p1 = foo

[file partial.py.2]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ...
p1 = functools.partial(foo, 1)

[out]
m.py:4: note: Revealed type is "def (a: builtins.int, b: builtins.str, c: builtins.int =) -> builtins.int"
m.py:5: error: Too few arguments
m.py:5: error: Argument 1 has incompatible type "str"; expected "int"
m.py:6: error: Argument 1 has incompatible type "str"; expected "int"
m.py:6: error: Argument 2 has incompatible type "int"; expected "str"
m.py:7: error: Too few arguments
m.py:7: error: Argument 1 has incompatible type "str"; expected "int"
m.py:8: error: Argument 2 has incompatible type "int"; expected "str"
==
m.py:4: note: Revealed type is "functools.partial[builtins.int]"
m.py:8: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
m.py:9: error: Too many arguments for "foo"
m.py:9: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
m.py:9: error: Argument 2 to "foo" has incompatible type "str"; expected "int"
m.py:10: error: Unexpected keyword argument "a" for "foo"
partial.py:4: note: "foo" defined here

0 comments on commit d1d3c78

Please sign in to comment.