Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for functools.partial #16939

Merged
merged 14 commits into from
May 23, 2024
Prev Previous commit
Next Next commit
server and incremental
  • Loading branch information
hauntsaninja committed May 21, 2024
commit c7f6d783cbab6bfefc570cfe3f2c57daba298623
3 changes: 3 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def visit_instance(self, inst: Instance) -> None:
a.accept(self)
if inst.last_known_value is not None:
inst.last_known_value.accept(self)
if inst.extra_attrs:
for v in inst.extra_attrs.attrs.values():
v.accept(self)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
type_ref = t.type_ref
Expand Down
8 changes: 8 additions & 0 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,19 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
return snapshot_simple_type(typ)

def visit_instance(self, typ: Instance) -> SnapshotItem:
if self.extra_attrs:
extra_attrs = (
tuple(sorted((k, self.visit(v)) for k, v in self.extra_attrs.attrs.items())),
tuple(self.extra_attrs.immutable),
)
else:
extra_attrs = ()
return (
"Instance",
encode_optional_str(typ.type.fullname),
snapshot_types(typ.args),
("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value),
extra_attrs,
)

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
Expand Down
20 changes: 20 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,23 @@ def copy(self) -> ExtraAttrs:
def __repr__(self) -> str:
return f"ExtraAttrs({self.attrs!r}, {self.immutable!r}, {self.mod_name!r})"

def serialize(self) -> JsonDict:
return {
".class": "ExtraAttrs",
"attrs": {k: v.serialize() for k, v in self.attrs.items()},
"immutable": list(self.immutable),
"mod_name": self.mod_name,
}

@classmethod
def deserialize(cls, data: JsonDict) -> ExtraAttrs:
assert data[".class"] == "ExtraAttrs"
return ExtraAttrs(
{k: deserialize_type(v) for k, v in data["attrs"].items()},
set(data["immutable"]),
data["mod_name"],
)


class Instance(ProperType):
"""An instance type of form C[T1, ..., Tn].
Expand Down Expand Up @@ -1468,6 +1485,7 @@ def serialize(self) -> JsonDict | str:
data["args"] = [arg.serialize() for arg in self.args]
if self.last_known_value is not None:
data["last_known_value"] = self.last_known_value.serialize()
data["extra_attrs"] = self.extra_attrs.serialize() if self.extra_attrs else None
return data

@classmethod
Expand All @@ -1486,6 +1504,8 @@ def deserialize(cls, data: JsonDict | str) -> Instance:
inst.type_ref = data["type_ref"] # Will be fixed up by fixup.py later.
if "last_known_value" in data:
inst.last_known_value = LiteralType.deserialize(data["last_known_value"])
if data.get("extra_attrs") is not None:
inst.extra_attrs = ExtraAttrs.deserialize(data["extra_attrs"])
return inst

def copy_modified(
Expand Down
60 changes: 60 additions & 0 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -6574,3 +6574,63 @@ class TheClass:
[out]
[out2]
tmp/a.py:3: note: Revealed type is "def (value: builtins.object) -> lib.TheClass.pyenum@6"


[case testIncrementalFunctoolsPartial]
import a

[file a.py]
from typing import Callable
from partial import p1, p2

p1(1, "a", 3) # OK
p1(1, "a", c=3) # OK
p1(1, b="a", c=3) # OK

reveal_type(p1)

def takes_callable_int(f: Callable[..., int]) -> None: ...
def takes_callable_str(f: Callable[..., str]) -> None: ...
takes_callable_int(p1)
takes_callable_str(p1)

p2("a") # OK
p2("a", 3) # OK
p2("a", c=3) # OK
p2(1, 3)
p2(1, "a", 3)
p2(a=1, b="a", c=3)

[file a.py.2]
from typing import Callable
from partial import p3

p3(1) # OK
p3(1, c=3) # OK
p3(a=1) # OK
p3(1, b="a", c=3) # OK, keywords can be clobbered
p3(1, 3)

[file partial.py]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ...

p1 = functools.partial(foo)
p2 = functools.partial(foo, 1)
p3 = functools.partial(foo, b="a")
[builtins fixtures/dict.pyi]
[out]
tmp/a.py:8: note: Revealed type is "functools.partial[builtins.int]"
tmp/a.py:13: error: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]"
tmp/a.py:13: note: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]"
tmp/a.py:18: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
tmp/a.py:19: error: Too many arguments for "foo"
tmp/a.py:19: error: Argument 1 to "foo" has incompatible type "int"; expected "str"
tmp/a.py:19: error: Argument 2 to "foo" has incompatible type "str"; expected "int"
tmp/a.py:20: error: Unexpected keyword argument "a" for "foo"
tmp/partial.py:4: note: "foo" defined here
[out2]
tmp/a.py:8: error: Too many positional arguments for "foo"
tmp/a.py:8: error: Argument 2 to "foo" has incompatible type "int"; expected "str"