Skip to content

Commit 2e0ca01

Browse files
committed
Improve zero argument support for super() in dataclasses
1 parent 0a3577b commit 2e0ca01

File tree

3 files changed

+94
-13
lines changed

3 files changed

+94
-13
lines changed

Lib/dataclasses.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,11 +1219,6 @@ def _get_slots(cls):
12191219

12201220

12211221
def _update_func_cell_for__class__(f, oldcls, newcls):
1222-
# Returns True if we update a cell, else False.
1223-
if f is None:
1224-
# f will be None in the case of a property where not all of
1225-
# fget, fset, and fdel are used. Nothing to do in that case.
1226-
return False
12271222
try:
12281223
idx = f.__code__.co_freevars.index("__class__")
12291224
except ValueError:
@@ -1232,13 +1227,36 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
12321227
# Fix the cell to point to the new class, if it's already pointing
12331228
# at the old class. I'm not convinced that the "is oldcls" test
12341229
# is needed, but other than performance can't hurt.
1235-
closure = f.__closure__[idx]
1236-
if closure.cell_contents is oldcls:
1237-
closure.cell_contents = newcls
1230+
cell = f.__closure__[idx]
1231+
if cell.cell_contents is oldcls:
1232+
cell.cell_contents = newcls
12381233
return True
12391234
return False
12401235

12411236

1237+
def _find_inner_functions(obj, _seen=None, _depth=0):
1238+
if _seen is None:
1239+
_seen = set()
1240+
if id(obj) in _seen:
1241+
return None
1242+
_seen.add(id(obj))
1243+
1244+
_depth += 1
1245+
if _depth > 2:
1246+
return None
1247+
1248+
obj = inspect.unwrap(obj)
1249+
1250+
for attr in dir(obj):
1251+
value = getattr(obj, attr, None)
1252+
if value is None:
1253+
continue
1254+
if isinstance(obj, types.FunctionType):
1255+
yield obj
1256+
return
1257+
yield from _find_inner_functions(value, _seen, _depth)
1258+
1259+
12421260
def _add_slots(cls, is_frozen, weakref_slot):
12431261
# Need to create a new class, since we can't set __slots__ after a
12441262
# class has been created, and the @dataclass decorator is called
@@ -1297,7 +1315,10 @@ def _add_slots(cls, is_frozen, weakref_slot):
12971315
# (the newly created one, which we're returning) and not the
12981316
# original class. We can break out of this loop as soon as we
12991317
# make an update, since all closures for a class will share a
1300-
# given cell.
1318+
# given cell. First we try to find a pure function/properties,
1319+
# and then fallback to inspecting custom descriptors.
1320+
1321+
custom_descriptors_to_check = []
13011322
for member in newcls.__dict__.values():
13021323
# If this is a wrapped function, unwrap it.
13031324
member = inspect.unwrap(member)
@@ -1306,10 +1327,27 @@ def _add_slots(cls, is_frozen, weakref_slot):
13061327
if _update_func_cell_for__class__(member, cls, newcls):
13071328
break
13081329
elif isinstance(member, property):
1309-
if (_update_func_cell_for__class__(member.fget, cls, newcls)
1310-
or _update_func_cell_for__class__(member.fset, cls, newcls)
1311-
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
1312-
break
1330+
for f in member.fget, member.fset, member.fdel:
1331+
if f is None:
1332+
continue
1333+
# unwrap once more in case function
1334+
# was wrapped before it became property
1335+
f = inspect.unwrap(f)
1336+
if _update_func_cell_for__class__(f, cls, newcls):
1337+
break
1338+
elif hasattr(member, "__get__") and not inspect.ismemberdescriptor(
1339+
member
1340+
):
1341+
# we don't want to inspect custom descriptors just yet
1342+
# there's still a chance we'll encounter a pure function
1343+
# or a property
1344+
custom_descriptors_to_check.append(member)
1345+
else:
1346+
# now let's ensure custom descriptors won't be left out
1347+
for descriptor in custom_descriptors_to_check:
1348+
for f in _find_inner_functions(descriptor):
1349+
if _update_func_cell_for__class__(f, cls, newcls):
1350+
break
13131351

13141352
return newcls
13151353

Lib/test/test_dataclasses/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5012,6 +5012,47 @@ def foo(self):
50125012

50135013
A().foo()
50145014

5015+
def test_wrapped_property(self):
5016+
def mydecorator(f):
5017+
@wraps(f)
5018+
def wrapper(*args, **kwargs):
5019+
return f(*args, **kwargs)
5020+
return wrapper
5021+
5022+
class B:
5023+
@property
5024+
def foo(self):
5025+
return "bar"
5026+
5027+
@dataclass(slots=True)
5028+
class A(B):
5029+
@property
5030+
@mydecorator
5031+
def foo(self):
5032+
return super().foo
5033+
5034+
self.assertEqual(A().foo, "bar")
5035+
5036+
def test_custom_descriptor(self):
5037+
class CustomDescriptor:
5038+
def __init__(self, f):
5039+
self._f = f
5040+
5041+
def __get__(self, instance, owner):
5042+
return self._f(instance)
5043+
5044+
class B:
5045+
def foo(self):
5046+
return "bar"
5047+
5048+
@dataclass(slots=True)
5049+
class A(B):
5050+
@CustomDescriptor
5051+
def foo(cls):
5052+
return super().foo()
5053+
5054+
self.assertEqual(A().foo, "bar")
5055+
50155056
def test_remembered_class(self):
50165057
# Apply the dataclass decorator manually (not when the class
50175058
# is created), so that we can keep a reference to the
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is
2+
specified and custom descriptor is used or `property` function is wrapped.

0 commit comments

Comments
 (0)