Skip to content

Commit 9165d41

Browse files
committed
feat(common): better inheritance support for Slotted and FrozenSlotted
1 parent 2e3a5a0 commit 9165d41

File tree

3 files changed

+77
-29
lines changed

3 files changed

+77
-29
lines changed

ibis/common/annotations.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ class Annotation(Slotted, Immutable):
100100
Annotations are used to mark fields in a class and to validate them.
101101
"""
102102

103-
__slots__ = ()
103+
__slots__ = ("pattern", "default")
104+
pattern: Pattern
105+
default: AnyType
104106

105107
def validate(self, name: str, value: AnyType, this: AnyType) -> AnyType:
106108
"""Validate the field.
@@ -142,10 +144,6 @@ class Attribute(Annotation):
142144
Callable to compute the default value of the field.
143145
"""
144146

145-
__slots__ = ("pattern", "default")
146-
pattern: Pattern
147-
default: AnyType
148-
149147
def __init__(self, pattern: Pattern = _any, default: AnyType = EMPTY):
150148
super().__init__(pattern=ensure_pattern(pattern), default=default)
151149

@@ -199,9 +197,7 @@ class Argument(Annotation):
199197
Defaults to positional or keyword.
200198
"""
201199

202-
__slots__ = ("pattern", "default", "typehint", "kind")
203-
pattern: Pattern
204-
default: AnyType
200+
__slots__ = ("typehint", "kind")
205201
typehint: AnyType
206202
kind: int
207203

ibis/common/bases.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class Final(Abstract):
116116
"""Prohibit subclassing."""
117117

118118
def __init_subclass__(cls, **kwargs):
119+
super().__init_subclass__(**kwargs)
119120
cls.__init_subclass__ = cls.__prohibit_inheritance__
120121

121122
@classmethod
@@ -178,37 +179,45 @@ def __cached_equals__(self, other) -> bool:
178179
return result
179180

180181

181-
class Slotted(Abstract):
182+
class SlottedMeta(AbstractMeta):
183+
def __new__(metacls, clsname, bases, dct, **kwargs):
184+
fields = dct.get("__fields__", dct.get("__slots__", ()))
185+
inherited = (getattr(base, "__fields__", ()) for base in bases)
186+
dct["__fields__"] = sum(inherited, ()) + fields
187+
return super().__new__(metacls, clsname, bases, dct, **kwargs)
188+
189+
190+
class Slotted(Abstract, metaclass=SlottedMeta):
182191
"""A lightweight alternative to `ibis.common.grounds.Annotable`.
183192
184193
The class is mostly used to reduce boilerplate code.
185194
"""
186195

187196
def __init__(self, **kwargs) -> None:
188-
for name, value in kwargs.items():
189-
object.__setattr__(self, name, value)
197+
for field in self.__fields__:
198+
object.__setattr__(self, field, kwargs[field])
190199

191200
def __eq__(self, other) -> bool:
192201
if self is other:
193202
return True
194203
if type(self) is not type(other):
195204
return NotImplemented
196-
return all(getattr(self, n) == getattr(other, n) for n in self.__slots__)
205+
return all(getattr(self, n) == getattr(other, n) for n in self.__fields__)
197206

198207
def __getstate__(self):
199-
return {k: getattr(self, k) for k in self.__slots__}
208+
return {k: getattr(self, k) for k in self.__fields__}
200209

201210
def __setstate__(self, state):
202211
for name, value in state.items():
203212
object.__setattr__(self, name, value)
204213

205214
def __repr__(self):
206-
fields = {k: getattr(self, k) for k in self.__slots__}
215+
fields = {k: getattr(self, k) for k in self.__fields__}
207216
fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items())
208217
return f"{self.__class__.__name__}({fieldstring})"
209218

210219
def __rich_repr__(self):
211-
for name in self.__slots__:
220+
for name in self.__fields__:
212221
yield name, getattr(self, name)
213222

214223

@@ -220,18 +229,21 @@ class FrozenSlotted(Slotted, Immutable, Hashable):
220229
"""
221230

222231
__slots__ = ("__precomputed_hash__",)
232+
__fields__ = ()
223233
__precomputed_hash__: int
224234

225235
def __init__(self, **kwargs) -> None:
226-
for name, value in kwargs.items():
227-
object.__setattr__(self, name, value)
228-
hashvalue = hash(tuple(kwargs.values()))
236+
values = []
237+
for field in self.__fields__:
238+
values.append(value := kwargs[field])
239+
object.__setattr__(self, field, value)
240+
hashvalue = hash((self.__class__, tuple(values)))
229241
object.__setattr__(self, "__precomputed_hash__", hashvalue)
230242

231243
def __setstate__(self, state):
232244
for name, value in state.items():
233245
object.__setattr__(self, name, value)
234-
hashvalue = hash(tuple(state.values()))
246+
hashvalue = hash((self.__class__, tuple(state.values())))
235247
object.__setattr__(self, "__precomputed_hash__", hashvalue)
236248

237249
def __hash__(self) -> int:

ibis/common/tests/test_bases.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,50 +266,90 @@ class B(A):
266266
class MyObj(Slotted):
267267
__slots__ = ("a", "b")
268268

269-
def __init__(self, a, b):
270-
super().__init__(a=a, b=b)
271-
272269

273270
def test_slotted():
274-
obj = MyObj(1, 2)
271+
obj = MyObj(a=1, b=2)
275272
assert obj.a == 1
276273
assert obj.b == 2
274+
assert obj.__fields__ == ("a", "b")
277275
assert obj.__slots__ == ("a", "b")
278276
with pytest.raises(AttributeError):
279277
obj.c = 3
280278

281-
obj2 = MyObj(1, 2)
279+
obj2 = MyObj(a=1, b=2)
282280
assert obj == obj2
283281
assert obj is not obj2
284282

285-
obj3 = MyObj(1, 3)
283+
obj3 = MyObj(a=1, b=3)
286284
assert obj != obj3
287285

288286
assert pickle.loads(pickle.dumps(obj)) == obj
289287

288+
with pytest.raises(KeyError):
289+
MyObj(a=1)
290+
291+
292+
class MyObj2(MyObj):
293+
__slots__ = ("c",)
294+
295+
296+
def test_slotted_inheritance():
297+
obj = MyObj2(a=1, b=2, c=3)
298+
assert obj.a == 1
299+
assert obj.b == 2
300+
assert obj.c == 3
301+
assert obj.__fields__ == ("a", "b", "c")
302+
assert obj.__slots__ == ("c",)
303+
with pytest.raises(AttributeError):
304+
obj.d = 4
305+
306+
obj2 = MyObj2(a=1, b=2, c=3)
307+
assert obj == obj2
308+
assert obj is not obj2
309+
310+
obj3 = MyObj2(a=1, b=2, c=4)
311+
assert obj != obj3
312+
assert pickle.loads(pickle.dumps(obj)) == obj
313+
314+
with pytest.raises(KeyError):
315+
MyObj2(a=1, b=2)
316+
290317

291318
class MyFrozenObj(FrozenSlotted):
292319
__slots__ = ("a", "b")
293320

294-
def __init__(self, a, b):
295-
super().__init__(a=a, b=b)
321+
322+
class MyFrozenObj2(MyFrozenObj):
323+
__slots__ = ("c", "d")
296324

297325

298326
def test_frozen_slotted():
299-
obj = MyFrozenObj(1, 2)
327+
obj = MyFrozenObj(a=1, b=2)
328+
300329
assert obj.a == 1
301330
assert obj.b == 2
331+
assert obj.__fields__ == ("a", "b")
302332
assert obj.__slots__ == ("a", "b")
303333
with pytest.raises(AttributeError):
304334
obj.b = 3
305335
with pytest.raises(AttributeError):
306336
obj.c = 3
307337

308-
obj2 = MyFrozenObj(1, 2)
338+
obj2 = MyFrozenObj(a=1, b=2)
309339
assert obj == obj2
310340
assert obj is not obj2
311341
assert hash(obj) == hash(obj2)
312342

313343
restored = pickle.loads(pickle.dumps(obj))
314344
assert restored == obj
315345
assert hash(restored) == hash(obj)
346+
347+
with pytest.raises(KeyError):
348+
MyFrozenObj(a=1)
349+
350+
351+
def test_frozen_slotted_inheritance():
352+
obj3 = MyFrozenObj2(a=1, b=2, c=3, d=4)
353+
assert obj3.__slots__ == ("c", "d")
354+
assert obj3.__fields__ == ("a", "b", "c", "d")
355+
assert pickle.loads(pickle.dumps(obj3)) == obj3

0 commit comments

Comments
 (0)