Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 89 additions & 31 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ class _MISSING_TYPE:

# The name of an attribute on the class where we store the Field
# objects. Also used to check if a class is a Data Class.
_MARKER = '__dataclass_fields__'
_FIELDS = '__dataclass_fields__'

# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'

# The name of the function, that if it exists, is called at the end of
# __init__.
Expand All @@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
# name and type are filled in after the fact, not in __init__. They're
# not known at the time this class is instantiated, but it's
# convenient if they're available later.
# When cls._MARKER is filled in with a list of Field objects, the name
# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
__slots__ = ('name',
Expand Down Expand Up @@ -236,6 +240,32 @@ def __repr__(self):
')')


class _DataclassParams:
__slots__ = ('init',
'repr',
'eq',
'order',
'unsafe_hash',
'frozen',
)
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
self.init = init
self.repr = repr
self.eq = eq
self.order = order
self.unsafe_hash = unsafe_hash
self.frozen = frozen

def __repr__(self):
return ('_DataclassParams('
f'init={self.init},'
f'repr={self.repr},'
f'eq={self.eq},'
f'order={self.order},'
f'unsafe_hash={self.unsafe_hash},'
f'frozen={self.frozen}'
')')

# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
Expand Down Expand Up @@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
args = ','.join(args)
body = '\n'.join(f' {b}' for b in body)

# Compute the text of the entire function.
txt = f'def {name}({args}){return_annotation}:\n{body}'

exec(txt, globals, locals)
Expand Down Expand Up @@ -432,12 +463,29 @@ def _repr_fn(fields):
')"'])


def _frozen_setattr(self, name, value):
raise FrozenInstanceError(f'cannot assign to field {name!r}')


def _frozen_delattr(self, name):
raise FrozenInstanceError(f'cannot delete field {name!r}')
def _frozen_get_del_attr(cls, fields):
# XXX: globals is modified on the first call to _create_fn, then the
# modified version is used in the second call. Is this okay?
globals = {'cls': cls,
'FrozenInstanceError': FrozenInstanceError}
if fields:
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
else:
# Special case for the zero-length tuple.
fields_str = '()'
return (_create_fn('__setattr__',
('self', 'name', 'value'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
f'super(cls, self).__setattr__(name, value)'),
globals=globals),
_create_fn('__delattr__',
('self', 'name'),
(f'if type(self) is cls or name in {fields_str}:',
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
f'super(cls, self).__delattr__(name)'),
globals=globals),
)


def _cmp_fn(name, op, self_tuple, other_tuple):
Expand Down Expand Up @@ -583,23 +631,32 @@ def _set_new_attribute(cls, name, value):
# version of this table.


def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
# Now that dicts retain insertion order, there's no reason to use
# an ordered dict. I am leveraging that ordering here, because
# derived class fields overwrite base class fields, but the order
# is defined by the base class, which is found first.
fields = {}

setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
unsafe_hash, frozen))

# Find our base classes in reverse MRO order, and exclude
# ourselves. In reversed order so that more derived classes
# override earlier field definitions in base classes.
# As long as we're iterating over them, see if any are frozen.
any_frozen_base = False
has_dataclass_bases = False
for b in cls.__mro__[-1:0:-1]:
# Only process classes that have been processed by our
# decorator. That is, they have a _MARKER attribute.
base_fields = getattr(b, _MARKER, None)
# decorator. That is, they have a _FIELDS attribute.
base_fields = getattr(b, _FIELDS, None)
if base_fields:
has_dataclass_bases = True
for f in base_fields.values():
fields[f.name] = f
if getattr(b, _PARAMS).frozen:
any_frozen_base = True

# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes)
Expand All @@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
else:
setattr(cls, f.name, f.default)

# We're inheriting from a frozen dataclass, but we're not frozen.
if cls.__setattr__ is _frozen_setattr and not frozen:
raise TypeError('cannot inherit non-frozen dataclass from a '
'frozen one')
# Check rules that apply if we are derived from any dataclasses.
if has_dataclass_bases:
# Raise an exception if any of our bases are frozen, but we're not.
if any_frozen_base and not frozen:
raise TypeError('cannot inherit non-frozen dataclass from a '
'frozen one')

# We're inheriting from a non-frozen dataclass, but we're frozen.
if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
and frozen):
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')
# Raise an exception if we're frozen, but none of our bases are.
if not any_frozen_base and frozen:
raise TypeError('cannot inherit frozen dataclass from a '
'non-frozen one')

# Remember all of the fields on our class (including bases). This
# Remember all of the fields on our class (including bases). This also
# marks this class as being a dataclass.
setattr(cls, _MARKER, fields)
setattr(cls, _FIELDS, fields)

# Was this class defined with an explicit __hash__? Note that if
# __eq__ is defined in this class, then python will automatically
Expand Down Expand Up @@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
'functools.total_ordering')

if frozen:
for name, fn in [('__setattr__', _frozen_setattr),
('__delattr__', _frozen_delattr)]:
if _set_new_attribute(cls, name, fn):
raise TypeError(f'Cannot overwrite attribute {name} '
# XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
for fn in _frozen_get_del_attr(cls, field_list):
if _set_new_attribute(cls, fn.__name__, fn):
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
f'in class {cls.__name__}')

# Decide if/how we're going to create a hash function.
Expand Down Expand Up @@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
"""

def wrap(cls):
return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)

# See if we're being called as @dataclass or @dataclass().
if _cls is None:
Expand All @@ -779,7 +837,7 @@ def fields(class_or_instance):

# Might it be worth caching this, per class?
try:
fields = getattr(class_or_instance, _MARKER)
fields = getattr(class_or_instance, _FIELDS)
except AttributeError:
raise TypeError('must be called with a dataclass type or instance')

Expand All @@ -790,13 +848,13 @@ def fields(class_or_instance):

def _is_dataclass_instance(obj):
"""Returns True if obj is an instance of a dataclass."""
return not isinstance(obj, type) and hasattr(obj, _MARKER)
return not isinstance(obj, type) and hasattr(obj, _FIELDS)


def is_dataclass(obj):
"""Returns True if obj is a dataclass or an instance of a
dataclass."""
return hasattr(obj, _MARKER)
return hasattr(obj, _FIELDS)


def asdict(obj, *, dict_factory=dict):
Expand Down Expand Up @@ -953,7 +1011,7 @@ class C:
# It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj.

for f in getattr(obj, _MARKER).values():
for f in getattr(obj, _FIELDS).values():
if not f.init:
# Error if this field is specified in changes.
if f.name in changes:
Expand Down
99 changes: 75 additions & 24 deletions Lib/test/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2476,41 +2476,92 @@ class D(C):
d = D(0, 10)
with self.assertRaises(FrozenInstanceError):
d.i = 5
with self.assertRaises(FrozenInstanceError):
d.j = 6
self.assertEqual(d.i, 0)
self.assertEqual(d.j, 10)

# Test both ways: with an intermediate normal (non-dataclass)
# class and without an intermediate class.
def test_inherit_nonfrozen_from_frozen(self):
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
@dataclass(frozen=True)
class C:
i: int

def test_inherit_from_nonfrozen_from_frozen(self):
@dataclass(frozen=True)
class C:
i: int
if intermediate_class:
class I(C): pass
else:
I = C

with self.assertRaisesRegex(TypeError,
'cannot inherit non-frozen dataclass from a frozen one'):
@dataclass
class D(C):
pass
with self.assertRaisesRegex(TypeError,
'cannot inherit non-frozen dataclass from a frozen one'):
@dataclass
class D(I):
pass

def test_inherit_from_frozen_from_nonfrozen(self):
@dataclass
class C:
i: int
def test_inherit_frozen_from_nonfrozen(self):
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
@dataclass
class C:
i: int

with self.assertRaisesRegex(TypeError,
'cannot inherit frozen dataclass from a non-frozen one'):
@dataclass(frozen=True)
class D(C):
pass
if intermediate_class:
class I(C): pass
else:
I = C

with self.assertRaisesRegex(TypeError,
'cannot inherit frozen dataclass from a non-frozen one'):
@dataclass(frozen=True)
class D(I):
pass

def test_inherit_from_normal_class(self):
class C:
pass
for intermediate_class in [True, False]:
with self.subTest(intermediate_class=intermediate_class):
class C:
pass

if intermediate_class:
class I(C): pass
else:
I = C

@dataclass(frozen=True)
class D(I):
i: int

d = D(10)
with self.assertRaises(FrozenInstanceError):
d.i = 5

def test_non_frozen_normal_derived(self):
# See bpo-32953.

@dataclass(frozen=True)
class D(C):
i: int
class D:
x: int
y: int = 10

d = D(10)
class S(D):
pass

s = S(3)
self.assertEqual(s.x, 3)
self.assertEqual(s.y, 10)
s.cached = True

# But can't change the frozen attributes.
with self.assertRaises(FrozenInstanceError):
d.i = 5
s.x = 5
with self.assertRaises(FrozenInstanceError):
s.y = 5
self.assertEqual(s.x, 3)
self.assertEqual(s.y, 10)
self.assertEqual(s.cached, True)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
If a non-dataclass inherits from a frozen dataclass, allow attributes to be
added to the derived class. Only attributes from from the frozen dataclass
cannot be assigned to. Require all dataclasses in a hierarchy to be either
all frozen or all non-frozen.