Skip to content

Commit

Permalink
dataclasses.replace: fall through to typeshed sig
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed Aug 26, 2023
1 parent 7f65cc7 commit a657d10
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 36 deletions.
25 changes: 1 addition & 24 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,25 +966,6 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
)


def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
t_name = format_type_bare(t, ctx.api.options)
if parent_t is t:
msg = (
f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'
)
else:
pt_name = format_type_bare(parent_t, ctx.api.options)
msg = (
f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'
)

ctx.api.fail(msg, ctx.context)


def _get_expanded_dataclasses_fields(
ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
) -> list[CallableType] | None:
Expand All @@ -993,9 +974,7 @@ def _get_expanded_dataclasses_fields(
For generic classes, the field types are expanded.
If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
elif isinstance(typ, UnionType):
if isinstance(typ, UnionType):
ret: list[CallableType] | None = []
for item in typ.relevant_items():
item = get_proper_type(item)
Expand All @@ -1012,14 +991,12 @@ def _get_expanded_dataclasses_fields(
elif isinstance(typ, Instance):
replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
if replace_sym is None:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None
replace_sig = replace_sym.type
assert isinstance(replace_sig, ProperType)
assert isinstance(replace_sig, CallableType)
return [expand_type_by_instance(replace_sig, typ)]
else:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None


Expand Down
44 changes: 35 additions & 9 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,8 @@ a2 = replace(a, x='42', q=42) # E: Argument "x" to "replace" of "A" has incompa
a2 = replace(a, q='42') # E: Argument "q" to "replace" of "A" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A"

[builtins fixtures/tuple.pyi]

[case testReplaceUnion]
from typing import Generic, Union, TypeVar
from dataclasses import dataclass, replace, InitVar
Expand Down Expand Up @@ -2135,7 +2137,7 @@ _ = replace(a_or_b, x=42, y=True, z='42', init_var=42) # E: Argument "z" to "re
_ = replace(a_or_b, x=42, y=True, w={}, init_var=42) # E: Argument "w" to "replace" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>
_ = replace(a_or_b, y=42, init_var=42) # E: Argument "y" to "replace" of "Union[A[int], B]" has incompatible type "int"; expected "bool"

[builtins fixtures/dataclasses.pyi]
[builtins fixtures/tuple.pyi]

[case testReplaceUnionOfTypeVar]
from typing import Generic, Union, TypeVar
Expand All @@ -2155,7 +2157,9 @@ TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: Union[TA, TB, int]) -> None:
a2 = replace(b_or_t) # E: Argument 1 to "replace" has type "Union[TA, TB, int]" whose item "TB" is not bound to a dataclass # E: Argument 1 to "replace" has incompatible type "Union[TA, TB, int]" whose item "int" is not a dataclass
a2 = replace(b_or_t) # E: Value of type variable "_DataclassT" of "replace" cannot be "Union[TA, TB, int]"

[builtins fixtures/tuple.pyi]

[case testReplaceTypeVarBoundNotDataclass]
from dataclasses import dataclass, replace
Expand All @@ -2167,16 +2171,18 @@ TNone = TypeVar('TNone', bound=None)
TUnion = TypeVar('TUnion', bound=Union[str, int])

def f1(t: TInt) -> None:
_ = replace(t, x=42) # E: Argument 1 to "replace" has a variable type "TInt" not bound to a dataclass
_ = replace(t, x=42) # E: Value of type variable "_DataclassT" of "replace" cannot be "TInt"

def f2(t: TAny) -> TAny:
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TAny" not bound to a dataclass
return replace(t, x='spam') # E: Value of type variable "_DataclassT" of "replace" cannot be "TAny"

def f3(t: TNone) -> TNone:
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TNone" not bound to a dataclass
return replace(t, x='spam') # E: Value of type variable "_DataclassT" of "replace" cannot be "TNone"

def f4(t: TUnion) -> TUnion:
return replace(t, x='spam') # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "str" is not a dataclass # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "int" is not a dataclass
return replace(t, x='spam') # E: Value of type variable "_DataclassT" of "replace" cannot be "TUnion"

[builtins fixtures/tuple.pyi]

[case testReplaceTypeVarBound]
from dataclasses import dataclass, replace
Expand All @@ -2201,6 +2207,8 @@ def f(t: TA) -> TA:
f(A(x=42))
f(B(x=42))

[builtins fixtures/tuple.pyi]

[case testReplaceAny]
from dataclasses import replace
from typing import Any
Expand All @@ -2209,17 +2217,33 @@ a: Any
a2 = replace(a)
reveal_type(a2) # N: Revealed type is "Any"

[builtins fixtures/tuple.pyi]

[case testReplaceNotDataclass]
from dataclasses import replace

replace(5) # E: Argument 1 to "replace" has incompatible type "int"; expected a dataclass
replace(5) # E: Value of type variable "_DataclassT" of "replace" cannot be "int"

class C:
pass

replace(C()) # E: Argument 1 to "replace" has incompatible type "C"; expected a dataclass
replace(C()) # E: Value of type variable "_DataclassT" of "replace" cannot be "C"

replace(None) # E: Argument 1 to "replace" has incompatible type "None"; expected a dataclass
replace(None) # E: Value of type variable "_DataclassT" of "replace" cannot be "None"

[builtins fixtures/tuple.pyi]

[case testReplaceIsDataclass]
from dataclasses import is_dataclass, replace

def f(x: object) -> None:
# error before type-guard
y = replace(x) # E: Value of type variable "_DataclassT" of "replace" cannot be "object"
# no error after type-guard
if is_dataclass(x) and not isinstance(x, type):
y = replace(x)

[builtins fixtures/tuple.pyi]

[case testReplaceGeneric]
from dataclasses import dataclass, replace, InitVar
Expand All @@ -2238,6 +2262,8 @@ reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
a2 = replace(a, x='42') # E: Argument "x" to "replace" of "A[int]" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

[builtins fixtures/tuple.pyi]

[case testPostInitCorrectSignature]
from typing import Any, Generic, TypeVar, Callable, Self
from dataclasses import dataclass, InitVar
Expand Down
6 changes: 5 additions & 1 deletion test-data/unit/lib-stub/_typeshed.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Protocol, TypeVar, Iterable
from dataclasses import Field
from typing import Any, ClassVar, Protocol, TypeVar, Iterable

_KT = TypeVar("_KT")
_VT_co = TypeVar("_VT_co", covariant=True)

class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
def keys(self) -> Iterable[_KT]: pass
def __getitem__(self, __key: _KT) -> _VT_co: pass

class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
15 changes: 13 additions & 2 deletions test-data/unit/lib-stub/dataclasses.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Callable, Generic, Mapping, Optional, TypeVar, overload, Type
from _typeshed import DataclassInstance
from typing import Any, Callable, Generic, Literal, Mapping, Optional, TypeVar, overload, Type
from typing_extensions import TypeGuard

_T = TypeVar('_T')
_DataclassT = TypeVar("_DataclassT", bound=DataclassInstance)

class InitVar(Generic[_T]):
...
Expand Down Expand Up @@ -33,4 +36,12 @@ def field(*,

class Field(Generic[_T]): pass

def replace(__obj: _T, **changes: Any) -> _T: ...
@overload
def is_dataclass(obj: DataclassInstance) -> Literal[True]: ...
@overload
def is_dataclass(obj: type) -> TypeGuard[type[DataclassInstance]]: ...
@overload
def is_dataclass(obj: object) -> TypeGuard[DataclassInstance | type[DataclassInstance]]: ...


def replace(__obj: _DataclassT, **changes: Any) -> _DataclassT: ...

0 comments on commit a657d10

Please sign in to comment.