Skip to content

Commit aff5e12

Browse files
committed
[FIX] core: cursor hooks API and implementation
Python 3.8 changed the equality rules for bound methods to be based on the *identity* of the receiver (`__self__`) rather than its *equality*. This means that in 3.7, methods from different instances will compare (and hash) equal, thereby landing in the same map "slot", but that isn't the case in 3.8. While it's usually not relevant, it's an issue for `GroupCalls` which is indexed by a function: in 3.7, that being a method from recordsets comparing equal will deduplicate them, but not anymore in 3.8, leading to duplicated callbacks (exactly the thing GroupCalls aims to avoid). Also, the API of `GroupCalls` turned out to be unusual and weird. The bug above is fixed by using a plain list for callbacks, thereby avoiding comparisons between registered functions. The API is now: callbacks.add(func) # add func to callbacks callbacks.run() # run all callbacks in addition order callbacks.clear() # remove all callbacks In order to handle aggregated data, the `callbacks` object provides a dictionary `callbacks.data` that any callback function can freely use. For the sake of consistency, the `callbacks.data` dict is automatically cleared upon execution of callbacks. Discovered by @william-andre Related to odoo#56583 References: * https://bugs.python.org/issue1617161 * python/cpython#7848 * https://docs.python.org/3/whatsnew/changelog.html#python-3-8-0-alpha-1 (no direct link because individual entries are not linkable, look for bpo-1617161) X-original-commit: d4b2e92
1 parent 9eee59b commit aff5e12

File tree

7 files changed

+118
-95
lines changed

7 files changed

+118
-95
lines changed

addons/mail/models/mail_thread.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ def _prepare_tracking(self, fields):
503503
fnames = self._get_tracked_fields().intersection(fields)
504504
if not fnames:
505505
return
506-
func = self.browse()._finalize_tracking
507-
[initial_values] = self.env.cr.precommit.add(func, dict)
506+
self.env.cr.precommit.add(self._finalize_tracking)
507+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
508508
for record in self:
509509
if not record.id:
510510
continue
@@ -517,16 +517,17 @@ def _discard_tracking(self):
517517
""" Prevent any tracking of fields on ``self``. """
518518
if not self._get_tracked_fields():
519519
return
520-
func = self.browse()._finalize_tracking
521-
[initial_values] = self.env.cr.precommit.add(func, dict)
520+
self.env.cr.precommit.add(self._finalize_tracking)
521+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
522522
# disable tracking by setting initial values to None
523523
for id_ in self.ids:
524524
initial_values[id_] = None
525525

526-
def _finalize_tracking(self, initial_values):
526+
def _finalize_tracking(self):
527527
""" Generate the tracking messages for the records that have been
528528
prepared with ``_prepare_tracking``.
529529
"""
530+
initial_values = self.env.cr.precommit.data.pop(f'mail.tracking.{self._name}', {})
530531
ids = [id_ for id_, vals in initial_values.items() if vals]
531532
if not ids:
532533
return

addons/mail/tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def _reset_mail_context(cls, record):
334334
def flush_tracking(self):
335335
""" Force the creation of tracking values. """
336336
self.env['base'].flush()
337-
self.cr.precommit()
337+
self.cr.precommit.run()
338338

339339
# ------------------------------------------------------------
340340
# MAIL MOCKS

odoo/addons/base/tests/test_misc.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -272,36 +272,54 @@ def test_01_code_and_format(self):
272272
self.assertEqual(misc.format_time(lang.with_context(lang='zh_CN').env, time_part, time_format='medium', lang_code='fr_FR'), '16:30:22')
273273

274274

275-
class TestGroupCalls(BaseCase):
276-
def test_callbacks(self):
275+
class TestCallbacks(BaseCase):
276+
def test_callback(self):
277277
log = []
278+
callbacks = misc.Callbacks()
278279

280+
# add foo
279281
def foo():
280282
log.append("foo")
281283

282-
def bar(items):
283-
log.extend(items)
284-
callbacks.add(baz)
284+
callbacks.add(foo)
285285

286-
def baz():
287-
log.append("baz")
286+
# add bar
287+
@callbacks.add
288+
def bar():
289+
log.append("bar")
288290

289-
callbacks = misc.GroupCalls()
291+
# add foo again
290292
callbacks.add(foo)
291-
callbacks.add(bar, list)[0].append(1)
292-
callbacks.add(bar, list)[0].append(2)
293-
self.assertEqual(log, [])
294293

295-
callbacks()
296-
self.assertEqual(log, ["foo", 1, 2, "baz"])
294+
# this should call foo(), bar(), foo()
295+
callbacks.run()
296+
self.assertEqual(log, ["foo", "bar", "foo"])
297+
298+
# this should do nothing
299+
callbacks.run()
300+
self.assertEqual(log, ["foo", "bar", "foo"])
301+
302+
def test_aggregate(self):
303+
log = []
304+
callbacks = misc.Callbacks()
305+
306+
# register foo once
307+
@callbacks.add
308+
def foo():
309+
log.append(callbacks.data["foo"])
310+
311+
# aggregate data
312+
callbacks.data.setdefault("foo", []).append(1)
313+
callbacks.data.setdefault("foo", []).append(2)
314+
callbacks.data.setdefault("foo", []).append(3)
297315

298-
callbacks()
299-
self.assertEqual(log, ["foo", 1, 2, "baz"])
316+
# foo() is called once
317+
callbacks.run()
318+
self.assertEqual(log, [[1, 2, 3]])
319+
self.assertFalse(callbacks.data)
300320

301-
callbacks.add(bar, list)[0].append(3)
302-
callbacks.clear()
303-
callbacks()
304-
self.assertEqual(log, ["foo", 1, 2, "baz"])
321+
callbacks.run()
322+
self.assertEqual(log, [[1, 2, 3]])
305323

306324

307325
class TestRemoveAccents(BaseCase):

odoo/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def checked_call(___dbname, *a, **kw):
351351
# flush here to avoid triggering a serialization error outside
352352
# of this context, which would not retry the call
353353
flush_env(self._cr)
354-
self._cr.precommit()
354+
self._cr.precommit.run()
355355
return result
356356

357357
if self.db:

odoo/sql_db.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,11 @@ def check(f, self, *args, **kwargs):
9696

9797

9898
class BaseCursor:
99-
""" Base class for cursors that manages pre/post commit/rollback hooks. """
99+
""" Base class for cursors that manage pre/post commit hooks. """
100100

101101
def __init__(self):
102-
self.precommit = tools.GroupCalls()
103-
self.postcommit = tools.GroupCalls()
104-
self.prerollback = tools.GroupCalls()
105-
self.postrollback = tools.GroupCalls()
102+
self.precommit = tools.Callbacks()
103+
self.postcommit = tools.Callbacks()
106104

107105
@contextmanager
108106
@check
@@ -111,20 +109,17 @@ def savepoint(self, flush=True):
111109
name = uuid.uuid1().hex
112110
if flush:
113111
flush_env(self, clear=False)
114-
self.precommit()
115-
self.prerollback.clear()
112+
self.precommit.run()
116113
self.execute('SAVEPOINT "%s"' % name)
117114
try:
118115
yield
119116
if flush:
120117
flush_env(self, clear=False)
121-
self.precommit()
122-
self.prerollback.clear()
118+
self.precommit.run()
123119
except Exception:
124120
if flush:
125121
clear_env(self)
126122
self.precommit.clear()
127-
self.prerollback()
128123
self.execute('ROLLBACK TO SAVEPOINT "%s"' % name)
129124
raise
130125
else:
@@ -428,17 +423,15 @@ def after(self, event, func):
428423
if event == 'commit':
429424
self.postcommit.add(func)
430425
elif event == 'rollback':
431-
self.postrollback.add(func)
426+
raise NotImplementedError()
432427

433428
@check
434429
def commit(self):
435430
""" Perform an SQL `COMMIT` """
436431
flush_env(self)
437-
self.precommit()
432+
self.precommit.run()
438433
result = self._cnx.commit()
439-
self.prerollback.clear()
440-
self.postrollback.clear()
441-
self.postcommit()
434+
self.postcommit.run()
442435
return result
443436

444437
@check
@@ -447,9 +440,7 @@ def rollback(self):
447440
clear_env(self)
448441
self.precommit.clear()
449442
self.postcommit.clear()
450-
self.prerollback()
451443
result = self._cnx.rollback()
452-
self.postrollback()
453444
return result
454445

455446
@check
@@ -506,23 +497,18 @@ def autocommit(self, on):
506497
def commit(self):
507498
""" Perform an SQL `COMMIT` """
508499
flush_env(self)
509-
self.precommit()
500+
self.precommit.run()
510501
self._cursor.execute('SAVEPOINT "%s"' % self._savepoint)
511-
self.prerollback.clear()
512-
# ignore post-commit/rollback hooks
502+
# ignore post-commit hooks
513503
self.postcommit.clear()
514-
self.postrollback.clear()
515504

516505
@check
517506
def rollback(self):
518507
""" Perform an SQL `ROLLBACK` """
519508
clear_env(self)
520509
self.precommit.clear()
521-
self.prerollback()
522510
self._cursor.execute('ROLLBACK TO SAVEPOINT "%s"' % self._savepoint)
523-
# ignore post-commit/rollback hooks
524511
self.postcommit.clear()
525-
self.postrollback.clear()
526512

527513
def __getattr__(self, name):
528514
value = getattr(self._cursor, name)

odoo/tests/common.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -398,14 +398,14 @@ def get_unaccent_wrapper(cr):
398398

399399
if flush:
400400
self.env.user.flush()
401-
self.env.cr.precommit()
401+
self.env.cr.precommit.run()
402402

403403
with patch('odoo.sql_db.Cursor.execute', execute):
404404
with patch('odoo.osv.expression.get_unaccent_wrapper', get_unaccent_wrapper):
405405
yield actual_queries
406406
if flush:
407407
self.env.user.flush()
408-
self.env.cr.precommit()
408+
self.env.cr.precommit.run()
409409

410410
self.assertEqual(
411411
len(actual_queries), len(expected),
@@ -438,12 +438,12 @@ def assertQueryCount(self, default=0, flush=True, **counters):
438438
expected = counters.get(login, default)
439439
if flush:
440440
self.env.user.flush()
441-
self.env.cr.precommit()
441+
self.env.cr.precommit.run()
442442
count0 = self.cr.sql_log_count
443443
yield
444444
if flush:
445445
self.env.user.flush()
446-
self.env.cr.precommit()
446+
self.env.cr.precommit.run()
447447
count = self.cr.sql_log_count - count0
448448
if count != expected:
449449
# add some info on caller to allow semi-automatic update of query count
@@ -462,11 +462,11 @@ def assertQueryCount(self, default=0, flush=True, **counters):
462462
# same operations, otherwise the caches might not be ready!
463463
if flush:
464464
self.env.user.flush()
465-
self.env.cr.precommit()
465+
self.env.cr.precommit.run()
466466
yield
467467
if flush:
468468
self.env.user.flush()
469-
self.env.cr.precommit()
469+
self.env.cr.precommit.run()
470470

471471
def assertRecordValues(self, records, expected_values):
472472
''' Compare a recordset with a list of dictionaries representing the expected results.

odoo/tools/misc.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,54 +1084,72 @@ def add(self, elem):
10841084
OrderedSet.add(self, elem)
10851085

10861086

1087-
class GroupCalls:
1088-
""" A collection of callbacks with support for aggregated arguments. Upon
1089-
call, every registered function is called once with positional arguments.
1090-
When registering a function, a tuple of positional arguments is returned, so
1091-
that the caller can modify the arguments in place. This allows to
1092-
accumulate some data to process once::
1087+
class Callbacks:
1088+
""" A simple queue of callback functions. Upon run, every function is
1089+
called (in addition order), and the queue is emptied.
10931090
1094-
callbacks = GroupCalls()
1091+
callbacks = Callbacks()
10951092
1096-
# register print (by default with a list)
1097-
[args] = callbacks.register(print, list)
1098-
args.append(42)
1093+
# add foo
1094+
def foo():
1095+
print("foo")
10991096
1100-
# add an element to the list to print
1101-
[args] = callbacks.register(print, list)
1102-
args.append(43)
1097+
callbacks.add(foo)
11031098
1104-
# print "[42, 43]"
1105-
callbacks()
1099+
# add bar
1100+
callbacks.add
1101+
def bar():
1102+
print("bar")
1103+
1104+
# add foo again
1105+
callbacks.add(foo)
1106+
1107+
# call foo(), bar(), foo(), then clear the callback queue
1108+
callbacks.run()
1109+
1110+
The queue also provides a ``data`` dictionary, that may be freely used to
1111+
store anything, but is mostly aimed at aggregating data for callbacks. The
1112+
dictionary is automatically cleared by ``run()`` once all callback functions
1113+
have been called.
1114+
1115+
# register foo to process aggregated data
1116+
@callbacks.add
1117+
def foo():
1118+
print(sum(callbacks.data['foo']))
1119+
1120+
callbacks.data.setdefault('foo', []).append(1)
1121+
...
1122+
callbacks.data.setdefault('foo', []).append(2)
1123+
...
1124+
callbacks.data.setdefault('foo', []).append(3)
1125+
1126+
# call foo(), which prints 6
1127+
callbacks.run()
1128+
1129+
Given the global nature of ``data``, the keys should identify in a unique
1130+
way the data being stored. It is recommended to use strings with a
1131+
structure like ``"{module}.{feature}"``.
11061132
"""
1133+
__slots__ = ['_funcs', 'data']
1134+
11071135
def __init__(self):
1108-
self._func_args = {} # {func: args}
1136+
self._funcs = []
1137+
self.data = {}
11091138

1110-
def __call__(self):
1111-
""" Call all the registered functions (in first addition order) with
1112-
their respective arguments. Only recurrent functions remain registered
1113-
after the call.
1114-
"""
1115-
func_args = self._func_args
1116-
while func_args:
1117-
func = next(iter(func_args))
1118-
args = func_args.pop(func)
1119-
func(*args)
1120-
1121-
def add(self, func, *types):
1122-
""" Register the given function, and return the tuple of positional
1123-
arguments to call the function with. If the function is not registered
1124-
yet, the list of arguments is made up by invoking the given types.
1125-
"""
1126-
try:
1127-
return self._func_args[func]
1128-
except KeyError:
1129-
args = self._func_args[func] = [type_() for type_ in types]
1130-
return args
1139+
def add(self, func):
1140+
""" Add the given function. """
1141+
self._funcs.append(func)
1142+
1143+
def run(self):
1144+
""" Call all the functions (in addition order), then clear. """
1145+
for func in self._funcs:
1146+
func()
1147+
self.clear()
11311148

11321149
def clear(self):
1133-
""" Remove all callbacks from self. """
1134-
self._func_args.clear()
1150+
""" Remove all callbacks and data from self. """
1151+
self._funcs.clear()
1152+
self.data.clear()
11351153

11361154

11371155
class IterableGenerator:

0 commit comments

Comments
 (0)