Skip to content

Commit 68058da

Browse files
committed
[FIX] core: cursor hooks API and implementation
Python 3.8 changed the equality rules for bound methods to be based on the *identity* of the receiver (`__self__`) rather than its *equality*. This means that in 3.7, methods from different instances will compare (and hash) equal, thereby landing in the same map "slot", but that isn't the case in 3.8. While it's usually not relevant, it's an issue for `GroupCalls` which is indexed by a function: in 3.7, that being a method from recordsets comparing equal will deduplicate them, but not anymore in 3.8, leading to duplicated callbacks (exactly the thing GroupCalls aims to avoid). Also, the API of `GroupCalls` turned out to be unusual and weird. The bug above is fixed by using a plain list for callbacks, thereby avoiding comparisons between registered functions. The API is now: callbacks.add(func) # add func to callbacks callbacks.run() # run all callbacks in addition order callbacks.clear() # remove all callbacks In order to handle aggregated data, the `callbacks` object provides a dictionary `callbacks.data` that any callback function can freely use. For the sake of consistency, the `callbacks.data` dict is automatically cleared upon execution of callbacks. Discovered by @william-andre Related to odoo#56583 References: * https://bugs.python.org/issue1617161 * python/cpython#7848 * https://docs.python.org/3/whatsnew/changelog.html#python-3-8-0-alpha-1 (no direct link because individual entries are not linkable, look for bpo-1617161) X-original-commit: a3a4d14
1 parent 612a8f7 commit 68058da

File tree

7 files changed

+116
-93
lines changed

7 files changed

+116
-93
lines changed

addons/mail/models/mail_thread.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,8 @@ def _prepare_tracking(self, fields):
419419
fnames = self._get_tracked_fields().intersection(fields)
420420
if not fnames:
421421
return
422-
func = self.browse()._finalize_tracking
423-
[initial_values] = self.env.cr.precommit.add(func, dict)
422+
self.env.cr.precommit.add(self._finalize_tracking)
423+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
424424
for record in self:
425425
if not record.id:
426426
continue
@@ -433,16 +433,17 @@ def _discard_tracking(self):
433433
""" Prevent any tracking of fields on ``self``. """
434434
if not self._get_tracked_fields():
435435
return
436-
func = self.browse()._finalize_tracking
437-
[initial_values] = self.env.cr.precommit.add(func, dict)
436+
self.env.cr.precommit.add(self._finalize_tracking)
437+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
438438
# disable tracking by setting initial values to None
439439
for id_ in self.ids:
440440
initial_values[id_] = None
441441

442-
def _finalize_tracking(self, initial_values):
442+
def _finalize_tracking(self):
443443
""" Generate the tracking messages for the records that have been
444444
prepared with ``_prepare_tracking``.
445445
"""
446+
initial_values = self.env.cr.precommit.data.pop(f'mail.tracking.{self._name}', {})
446447
ids = [id_ for id_, vals in initial_values.items() if vals]
447448
if not ids:
448449
return

addons/mail/tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _reset_mail_context(cls, record):
325325
def flush_tracking(self):
326326
""" Force the creation of tracking values. """
327327
self.env['base'].flush()
328-
self.cr.precommit()
328+
self.cr.precommit.run()
329329

330330
# ------------------------------------------------------------
331331
# MAIL MOCKS

odoo/addons/base/tests/test_misc.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -272,36 +272,54 @@ def test_01_code_and_format(self):
272272
self.assertEqual(misc.format_time(lang.with_context(lang='zh_CN').env, time_part, time_format='medium', lang_code='fr_FR'), '16:30:22')
273273

274274

275-
class TestGroupCalls(BaseCase):
276-
def test_callbacks(self):
275+
class TestCallbacks(BaseCase):
276+
def test_callback(self):
277277
log = []
278+
callbacks = misc.Callbacks()
278279

280+
# add foo
279281
def foo():
280282
log.append("foo")
281283

282-
def bar(items):
283-
log.extend(items)
284-
callbacks.add(baz)
284+
callbacks.add(foo)
285285

286-
def baz():
287-
log.append("baz")
286+
# add bar
287+
@callbacks.add
288+
def bar():
289+
log.append("bar")
288290

289-
callbacks = misc.GroupCalls()
291+
# add foo again
290292
callbacks.add(foo)
291-
callbacks.add(bar, list)[0].append(1)
292-
callbacks.add(bar, list)[0].append(2)
293-
self.assertEqual(log, [])
294293

295-
callbacks()
296-
self.assertEqual(log, ["foo", 1, 2, "baz"])
294+
# this should call foo(), bar(), foo()
295+
callbacks.run()
296+
self.assertEqual(log, ["foo", "bar", "foo"])
297+
298+
# this should do nothing
299+
callbacks.run()
300+
self.assertEqual(log, ["foo", "bar", "foo"])
301+
302+
def test_aggregate(self):
303+
log = []
304+
callbacks = misc.Callbacks()
305+
306+
# register foo once
307+
@callbacks.add
308+
def foo():
309+
log.append(callbacks.data["foo"])
310+
311+
# aggregate data
312+
callbacks.data.setdefault("foo", []).append(1)
313+
callbacks.data.setdefault("foo", []).append(2)
314+
callbacks.data.setdefault("foo", []).append(3)
297315

298-
callbacks()
299-
self.assertEqual(log, ["foo", 1, 2, "baz"])
316+
# foo() is called once
317+
callbacks.run()
318+
self.assertEqual(log, [[1, 2, 3]])
319+
self.assertFalse(callbacks.data)
300320

301-
callbacks.add(bar, list)[0].append(3)
302-
callbacks.clear()
303-
callbacks()
304-
self.assertEqual(log, ["foo", 1, 2, "baz"])
321+
callbacks.run()
322+
self.assertEqual(log, [[1, 2, 3]])
305323

306324

307325
class TestRemoveAccents(BaseCase):

odoo/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def checked_call(___dbname, *a, **kw):
349349
# flush here to avoid triggering a serialization error outside
350350
# of this context, which would not retry the call
351351
flush_env(self._cr)
352-
self._cr.precommit()
352+
self._cr.precommit.run()
353353
return result
354354

355355
if self.db:

odoo/sql_db.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,11 @@ def check(f, self, *args, **kwargs):
9696

9797

9898
class BaseCursor:
99-
""" Base class for cursors that manages pre/post commit/rollback hooks. """
99+
""" Base class for cursors that manage pre/post commit hooks. """
100100

101101
def __init__(self):
102-
self.precommit = tools.GroupCalls()
103-
self.postcommit = tools.GroupCalls()
104-
self.prerollback = tools.GroupCalls()
105-
self.postrollback = tools.GroupCalls()
102+
self.precommit = tools.Callbacks()
103+
self.postcommit = tools.Callbacks()
106104

107105
@contextmanager
108106
@check
@@ -111,20 +109,17 @@ def savepoint(self, flush=True):
111109
name = uuid.uuid1().hex
112110
if flush:
113111
flush_env(self, clear=False)
114-
self.precommit()
115-
self.prerollback.clear()
112+
self.precommit.run()
116113
self.execute('SAVEPOINT "%s"' % name)
117114
try:
118115
yield
119116
if flush:
120117
flush_env(self, clear=False)
121-
self.precommit()
122-
self.prerollback.clear()
118+
self.precommit.run()
123119
except Exception:
124120
if flush:
125121
clear_env(self)
126122
self.precommit.clear()
127-
self.prerollback()
128123
self.execute('ROLLBACK TO SAVEPOINT "%s"' % name)
129124
raise
130125
else:
@@ -428,17 +423,15 @@ def after(self, event, func):
428423
if event == 'commit':
429424
self.postcommit.add(func)
430425
elif event == 'rollback':
431-
self.postrollback.add(func)
426+
raise NotImplementedError()
432427

433428
@check
434429
def commit(self):
435430
""" Perform an SQL `COMMIT` """
436431
flush_env(self)
437-
self.precommit()
432+
self.precommit.run()
438433
result = self._cnx.commit()
439-
self.prerollback.clear()
440-
self.postrollback.clear()
441-
self.postcommit()
434+
self.postcommit.run()
442435
return result
443436

444437
@check
@@ -447,9 +440,7 @@ def rollback(self):
447440
clear_env(self)
448441
self.precommit.clear()
449442
self.postcommit.clear()
450-
self.prerollback()
451443
result = self._cnx.rollback()
452-
self.postrollback()
453444
return result
454445

455446
@check
@@ -506,23 +497,18 @@ def autocommit(self, on):
506497
def commit(self):
507498
""" Perform an SQL `COMMIT` """
508499
flush_env(self)
509-
self.precommit()
500+
self.precommit.run()
510501
self._cursor.execute('SAVEPOINT "%s"' % self._savepoint)
511-
self.prerollback.clear()
512-
# ignore post-commit/rollback hooks
502+
# ignore post-commit hooks
513503
self.postcommit.clear()
514-
self.postrollback.clear()
515504

516505
@check
517506
def rollback(self):
518507
""" Perform an SQL `ROLLBACK` """
519508
clear_env(self)
520509
self.precommit.clear()
521-
self.prerollback()
522510
self._cursor.execute('ROLLBACK TO SAVEPOINT "%s"' % self._savepoint)
523-
# ignore post-commit/rollback hooks
524511
self.postcommit.clear()
525-
self.postrollback.clear()
526512

527513
def __getattr__(self, name):
528514
value = getattr(self._cursor, name)

odoo/tests/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,12 @@ def assertQueryCount(self, default=0, flush=True, **counters):
431431
expected = counters.get(login, default)
432432
if flush:
433433
self.env.user.flush()
434-
self.env.cr.precommit()
434+
self.env.cr.precommit.run()
435435
count0 = self.cr.sql_log_count
436436
yield
437437
if flush:
438438
self.env.user.flush()
439-
self.env.cr.precommit()
439+
self.env.cr.precommit.run()
440440
count = self.cr.sql_log_count - count0
441441
if count != expected:
442442
# add some info on caller to allow semi-automatic update of query count
@@ -455,11 +455,11 @@ def assertQueryCount(self, default=0, flush=True, **counters):
455455
# same operations, otherwise the caches might not be ready!
456456
if flush:
457457
self.env.user.flush()
458-
self.env.cr.precommit()
458+
self.env.cr.precommit.run()
459459
yield
460460
if flush:
461461
self.env.user.flush()
462-
self.env.cr.precommit()
462+
self.env.cr.precommit.run()
463463

464464
def assertRecordValues(self, records, expected_values):
465465
''' Compare a recordset with a list of dictionaries representing the expected results.

odoo/tools/misc.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,54 +1084,72 @@ def add(self, elem):
10841084
OrderedSet.add(self, elem)
10851085

10861086

1087-
class GroupCalls:
1088-
""" A collection of callbacks with support for aggregated arguments. Upon
1089-
call, every registered function is called once with positional arguments.
1090-
When registering a function, a tuple of positional arguments is returned, so
1091-
that the caller can modify the arguments in place. This allows to
1092-
accumulate some data to process once::
1087+
class Callbacks:
1088+
""" A simple queue of callback functions. Upon run, every function is
1089+
called (in addition order), and the queue is emptied.
10931090
1094-
callbacks = GroupCalls()
1091+
callbacks = Callbacks()
10951092
1096-
# register print (by default with a list)
1097-
[args] = callbacks.register(print, list)
1098-
args.append(42)
1093+
# add foo
1094+
def foo():
1095+
print("foo")
10991096
1100-
# add an element to the list to print
1101-
[args] = callbacks.register(print, list)
1102-
args.append(43)
1097+
callbacks.add(foo)
11031098
1104-
# print "[42, 43]"
1105-
callbacks()
1099+
# add bar
1100+
callbacks.add
1101+
def bar():
1102+
print("bar")
1103+
1104+
# add foo again
1105+
callbacks.add(foo)
1106+
1107+
# call foo(), bar(), foo(), then clear the callback queue
1108+
callbacks.run()
1109+
1110+
The queue also provides a ``data`` dictionary, that may be freely used to
1111+
store anything, but is mostly aimed at aggregating data for callbacks. The
1112+
dictionary is automatically cleared by ``run()`` once all callback functions
1113+
have been called.
1114+
1115+
# register foo to process aggregated data
1116+
@callbacks.add
1117+
def foo():
1118+
print(sum(callbacks.data['foo']))
1119+
1120+
callbacks.data.setdefault('foo', []).append(1)
1121+
...
1122+
callbacks.data.setdefault('foo', []).append(2)
1123+
...
1124+
callbacks.data.setdefault('foo', []).append(3)
1125+
1126+
# call foo(), which prints 6
1127+
callbacks.run()
1128+
1129+
Given the global nature of ``data``, the keys should identify in a unique
1130+
way the data being stored. It is recommended to use strings with a
1131+
structure like ``"{module}.{feature}"``.
11061132
"""
1133+
__slots__ = ['_funcs', 'data']
1134+
11071135
def __init__(self):
1108-
self._func_args = {} # {func: args}
1136+
self._funcs = []
1137+
self.data = {}
11091138

1110-
def __call__(self):
1111-
""" Call all the registered functions (in first addition order) with
1112-
their respective arguments. Only recurrent functions remain registered
1113-
after the call.
1114-
"""
1115-
func_args = self._func_args
1116-
while func_args:
1117-
func = next(iter(func_args))
1118-
args = func_args.pop(func)
1119-
func(*args)
1120-
1121-
def add(self, func, *types):
1122-
""" Register the given function, and return the tuple of positional
1123-
arguments to call the function with. If the function is not registered
1124-
yet, the list of arguments is made up by invoking the given types.
1125-
"""
1126-
try:
1127-
return self._func_args[func]
1128-
except KeyError:
1129-
args = self._func_args[func] = [type_() for type_ in types]
1130-
return args
1139+
def add(self, func):
1140+
""" Add the given function. """
1141+
self._funcs.append(func)
1142+
1143+
def run(self):
1144+
""" Call all the functions (in addition order), then clear. """
1145+
for func in self._funcs:
1146+
func()
1147+
self.clear()
11311148

11321149
def clear(self):
1133-
""" Remove all callbacks from self. """
1134-
self._func_args.clear()
1150+
""" Remove all callbacks and data from self. """
1151+
self._funcs.clear()
1152+
self.data.clear()
11351153

11361154

11371155
class IterableGenerator:

0 commit comments

Comments
 (0)