Skip to content

Improve error handling for dataclass inheritance #13531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,20 @@ def transform(self) -> bool:
tvar_def=order_tvar_def,
)

parent_decorator_arguments = []
for parent in info.mro[1:-1]:
parent_args = parent.metadata.get("dataclass")
if parent_args:
parent_decorator_arguments.append(parent_args)

if decorator_arguments["frozen"]:
if any(not parent["frozen"] for parent in parent_decorator_arguments):
ctx.api.fail("Cannot inherit frozen dataclass from a non-frozen one", info)
self._propertize_callables(attributes, settable=False)
self._freeze(attributes)
else:
if any(parent["frozen"] for parent in parent_decorator_arguments):
ctx.api.fail("Cannot inherit non-frozen dataclass from a frozen one", info)
self._propertize_callables(attributes)

if decorator_arguments["slots"]:
Expand Down Expand Up @@ -446,6 +456,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# copy() because we potentially modify all_attrs below and if this code requires debugging
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
known_super_attrs = set()
for info in cls.info.mro[1:-1]:
if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata:
# We haven't processed the base class yet. Need another pass.
Expand All @@ -467,6 +478,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
with state.strict_optional_set(ctx.api.options.strict_optional):
attr.expand_typevar_from_subtype(ctx.cls.info)
known_attrs.add(name)
known_super_attrs.add(name)
super_attrs.append(attr)
elif all_attrs:
# How early in the attribute list an attribute appears is determined by the
Expand All @@ -481,6 +493,14 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
all_attrs = super_attrs + all_attrs
all_attrs.sort(key=lambda a: a.kw_only)

for known_super_attr_name in known_super_attrs:
sym_node = cls.info.names.get(known_super_attr_name)
if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
ctx.api.fail(
"Dataclass attribute may only be overridden by another attribute",
sym_node.node,
)

# Ensure that arguments without a default don't follow
# arguments that have a default.
found_default = False
Expand Down Expand Up @@ -515,8 +535,8 @@ def _freeze(self, attributes: list[DataclassAttribute]) -> None:
sym_node = info.names.get(attr.name)
if sym_node is not None:
var = sym_node.node
assert isinstance(var, Var)
var.is_property = True
if isinstance(var, Var):
var.is_property = True
else:
var = attr.to_var()
var.info = info
Expand Down
86 changes: 84 additions & 2 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,66 @@ reveal_type(C) # N: Revealed type is "def (some_int: builtins.int, some_str: bu

[builtins fixtures/dataclasses.pyi]

[case testDataclassIncompatibleOverrides]
# flags: --python-version 3.7
from dataclasses import dataclass

@dataclass
class Base:
foo: int

@dataclass
class BadDerived1(Base):
def foo(self) -> int: # E: Dataclass attribute may only be overridden by another attribute \
# E: Signature of "foo" incompatible with supertype "Base"
return 1

@dataclass
class BadDerived2(Base):
@property # E: Dataclass attribute may only be overridden by another attribute
def foo(self) -> int: # E: Cannot override writeable attribute with read-only property
return 2

@dataclass
class BadDerived3(Base):
class foo: pass # E: Dataclass attribute may only be overridden by another attribute
[builtins fixtures/dataclasses.pyi]

[case testDataclassMultipleInheritance]
# flags: --python-version 3.7
from dataclasses import dataclass

class Unrelated:
foo: str

@dataclass
class Base:
bar: int

@dataclass
class Derived(Base, Unrelated):
pass

d = Derived(3)
reveal_type(d.foo) # N: Revealed type is "builtins.str"
reveal_type(d.bar) # N: Revealed type is "builtins.int"
[builtins fixtures/dataclasses.pyi]

[case testDataclassIncompatibleFrozenOverride]
# flags: --python-version 3.7
from dataclasses import dataclass

@dataclass(frozen=True)
class Base:
foo: int

@dataclass(frozen=True)
class BadDerived(Base):
@property # E: Dataclass attribute may only be overridden by another attribute
def foo(self) -> int:
return 3
[builtins fixtures/dataclasses.pyi]

[case testDataclassesFreezing]
# flags: --python-version 3.7
from dataclasses import dataclass
Expand All @@ -200,6 +260,28 @@ john.name = 'Ben' # E: Property "name" defined in "Person" is read-only

[builtins fixtures/dataclasses.pyi]

[case testDataclassesInconsistentFreezing]
# flags: --python-version 3.7
from dataclasses import dataclass

@dataclass(frozen=True)
class FrozenBase:
pass

@dataclass
class BadNormalDerived(FrozenBase): # E: Cannot inherit non-frozen dataclass from a frozen one
pass

@dataclass
class NormalBase:
pass

@dataclass(frozen=True)
class BadFrozenDerived(NormalBase): # E: Cannot inherit frozen dataclass from a non-frozen one
pass

[builtins fixtures/dataclasses.pyi]

[case testDataclassesFields]
# flags: --python-version 3.7
from dataclasses import dataclass, field
Expand Down Expand Up @@ -1283,9 +1365,9 @@ from dataclasses import dataclass
class A:
foo: int

@dataclass
@dataclass(frozen=True)
class B(A):
@property
@property # E: Dataclass attribute may only be overridden by another attribute
def foo(self) -> int: pass

reveal_type(B) # N: Revealed type is "def (foo: builtins.int) -> __main__.B"
Expand Down