Skip to content

Commit 38c03e6

Browse files
committed
[FIX] core: cursor hooks API and implementation
Python 3.8 changed the equality rules for bound methods to be based on the *identity* of the receiver (`__self__`) rather than its *equality*. This means that in 3.7, methods from different instances will compare (and hash) equal, thereby landing in the same map "slot", but that isn't the case in 3.8. While it's usually not relevant, it's an issue for `GroupCalls` which is indexed by a function: in 3.7, that being a method from recordsets comparing equal will deduplicate them, but not anymore in 3.8, leading to duplicated callbacks (exactly the thing GroupCalls aims to avoid). Also, the API of `GroupCalls` turned out to be unusual and weird. The bug above is fixed by using a plain list for callbacks, thereby avoiding comparisons between registered functions. The API is now: callbacks.add(func) # add func to callbacks callbacks.run() # run all callbacks in addition order callbacks.clear() # remove all callbacks In order to handle aggregated data, the `callbacks` object provides a dictionary `callbacks.data` that any callback function can freely use. For the sake of consistency, the `callbacks.data` dict is automatically cleared upon execution of callbacks. Discovered by @william-andre Related to odoo#56583 References: * https://bugs.python.org/issue1617161 * python/cpython#7848 * https://docs.python.org/3/whatsnew/changelog.html#python-3-8-0-alpha-1 (no direct link because individual entries are not linkable, look for bpo-1617161)
1 parent da08955 commit 38c03e6

File tree

7 files changed

+116
-93
lines changed

7 files changed

+116
-93
lines changed

addons/mail/models/mail_thread.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ def _prepare_tracking(self, fields):
430430
fnames = self._get_tracked_fields().intersection(fields)
431431
if not fnames:
432432
return
433-
func = self.browse()._finalize_tracking
434-
[initial_values] = self.env.cr.precommit.add(func, dict)
433+
self.env.cr.precommit.add(self._finalize_tracking)
434+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
435435
for record in self:
436436
if not record.id:
437437
continue
@@ -444,16 +444,17 @@ def _discard_tracking(self):
444444
""" Prevent any tracking of fields on ``self``. """
445445
if not self._get_tracked_fields():
446446
return
447-
func = self.browse()._finalize_tracking
448-
[initial_values] = self.env.cr.precommit.add(func, dict)
447+
self.env.cr.precommit.add(self._finalize_tracking)
448+
initial_values = self.env.cr.precommit.data.setdefault(f'mail.tracking.{self._name}', {})
449449
# disable tracking by setting initial values to None
450450
for id_ in self.ids:
451451
initial_values[id_] = None
452452

453-
def _finalize_tracking(self, initial_values):
453+
def _finalize_tracking(self):
454454
""" Generate the tracking messages for the records that have been
455455
prepared with ``_prepare_tracking``.
456456
"""
457+
initial_values = self.env.cr.precommit.data.pop(f'mail.tracking.{self._name}', {})
457458
ids = [id_ for id_, vals in initial_values.items() if vals]
458459
if not ids:
459460
return

addons/test_mail/tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _create_template(cls, model, template_values=None):
5454
def flush_tracking(self):
5555
""" Force the creation of tracking values. """
5656
self.env['base'].flush()
57-
self.cr.precommit()
57+
self.cr.precommit.run()
5858

5959

6060
class TestMailMultiCompanyCommon(TestMailCommon):

odoo/addons/base/tests/test_misc.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -272,36 +272,54 @@ def test_01_code_and_format(self):
272272
self.assertEqual(misc.format_time(lang.with_context(lang='zh_CN').env, time_part, time_format='medium', lang_code='fr_FR'), '16:30:22')
273273

274274

275-
class TestGroupCalls(BaseCase):
276-
def test_callbacks(self):
275+
class TestCallbacks(BaseCase):
276+
def test_callback(self):
277277
log = []
278+
callbacks = misc.Callbacks()
278279

280+
# add foo
279281
def foo():
280282
log.append("foo")
281283

282-
def bar(items):
283-
log.extend(items)
284-
callbacks.add(baz)
284+
callbacks.add(foo)
285285

286-
def baz():
287-
log.append("baz")
286+
# add bar
287+
@callbacks.add
288+
def bar():
289+
log.append("bar")
288290

289-
callbacks = misc.GroupCalls()
291+
# add foo again
290292
callbacks.add(foo)
291-
callbacks.add(bar, list)[0].append(1)
292-
callbacks.add(bar, list)[0].append(2)
293-
self.assertEqual(log, [])
294293

295-
callbacks()
296-
self.assertEqual(log, ["foo", 1, 2, "baz"])
294+
# this should call foo(), bar(), foo()
295+
callbacks.run()
296+
self.assertEqual(log, ["foo", "bar", "foo"])
297+
298+
# this should do nothing
299+
callbacks.run()
300+
self.assertEqual(log, ["foo", "bar", "foo"])
301+
302+
def test_aggregate(self):
303+
log = []
304+
callbacks = misc.Callbacks()
305+
306+
# register foo once
307+
@callbacks.add
308+
def foo():
309+
log.append(callbacks.data["foo"])
310+
311+
# aggregate data
312+
callbacks.data.setdefault("foo", []).append(1)
313+
callbacks.data.setdefault("foo", []).append(2)
314+
callbacks.data.setdefault("foo", []).append(3)
297315

298-
callbacks()
299-
self.assertEqual(log, ["foo", 1, 2, "baz"])
316+
# foo() is called once
317+
callbacks.run()
318+
self.assertEqual(log, [[1, 2, 3]])
319+
self.assertFalse(callbacks.data)
300320

301-
callbacks.add(bar, list)[0].append(3)
302-
callbacks.clear()
303-
callbacks()
304-
self.assertEqual(log, ["foo", 1, 2, "baz"])
321+
callbacks.run()
322+
self.assertEqual(log, [[1, 2, 3]])
305323

306324

307325
class TestRemoveAccents(BaseCase):

odoo/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def checked_call(___dbname, *a, **kw):
352352
# flush here to avoid triggering a serialization error outside
353353
# of this context, which would not retry the call
354354
flush_env(self._cr)
355-
self._cr.precommit()
355+
self._cr.precommit.run()
356356
return result
357357

358358
if self.db:

odoo/sql_db.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,11 @@ def check(f, self, *args, **kwargs):
9696

9797

9898
class BaseCursor:
99-
""" Base class for cursors that manages pre/post commit/rollback hooks. """
99+
""" Base class for cursors that manage pre/post commit hooks. """
100100

101101
def __init__(self):
102-
self.precommit = tools.GroupCalls()
103-
self.postcommit = tools.GroupCalls()
104-
self.prerollback = tools.GroupCalls()
105-
self.postrollback = tools.GroupCalls()
102+
self.precommit = tools.Callbacks()
103+
self.postcommit = tools.Callbacks()
106104

107105
@contextmanager
108106
@check
@@ -111,20 +109,17 @@ def savepoint(self, flush=True):
111109
name = uuid.uuid1().hex
112110
if flush:
113111
flush_env(self, clear=False)
114-
self.precommit()
115-
self.prerollback.clear()
112+
self.precommit.run()
116113
self.execute('SAVEPOINT "%s"' % name)
117114
try:
118115
yield
119116
if flush:
120117
flush_env(self, clear=False)
121-
self.precommit()
122-
self.prerollback.clear()
118+
self.precommit.run()
123119
except Exception:
124120
if flush:
125121
clear_env(self)
126122
self.precommit.clear()
127-
self.prerollback()
128123
self.execute('ROLLBACK TO SAVEPOINT "%s"' % name)
129124
raise
130125
else:
@@ -428,17 +423,15 @@ def after(self, event, func):
428423
if event == 'commit':
429424
self.postcommit.add(func)
430425
elif event == 'rollback':
431-
self.postrollback.add(func)
426+
raise NotImplementedError()
432427

433428
@check
434429
def commit(self):
435430
""" Perform an SQL `COMMIT` """
436431
flush_env(self)
437-
self.precommit()
432+
self.precommit.run()
438433
result = self._cnx.commit()
439-
self.prerollback.clear()
440-
self.postrollback.clear()
441-
self.postcommit()
434+
self.postcommit.run()
442435
return result
443436

444437
@check
@@ -447,9 +440,7 @@ def rollback(self):
447440
clear_env(self)
448441
self.precommit.clear()
449442
self.postcommit.clear()
450-
self.prerollback()
451443
result = self._cnx.rollback()
452-
self.postrollback()
453444
return result
454445

455446
@check
@@ -506,23 +497,18 @@ def autocommit(self, on):
506497
def commit(self):
507498
""" Perform an SQL `COMMIT` """
508499
flush_env(self)
509-
self.precommit()
500+
self.precommit.run()
510501
self._cursor.execute('SAVEPOINT "%s"' % self._savepoint)
511-
self.prerollback.clear()
512-
# ignore post-commit/rollback hooks
502+
# ignore post-commit hooks
513503
self.postcommit.clear()
514-
self.postrollback.clear()
515504

516505
@check
517506
def rollback(self):
518507
""" Perform an SQL `ROLLBACK` """
519508
clear_env(self)
520509
self.precommit.clear()
521-
self.prerollback()
522510
self._cursor.execute('ROLLBACK TO SAVEPOINT "%s"' % self._savepoint)
523-
# ignore post-commit/rollback hooks
524511
self.postcommit.clear()
525-
self.postrollback.clear()
526512

527513
def __getattr__(self, name):
528514
value = getattr(self._cursor, name)

odoo/tests/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,12 @@ def assertQueryCount(self, default=0, flush=True, **counters):
375375
expected = counters.get(login, default)
376376
if flush:
377377
self.env.user.flush()
378-
self.env.cr.precommit()
378+
self.env.cr.precommit.run()
379379
count0 = self.cr.sql_log_count
380380
yield
381381
if flush:
382382
self.env.user.flush()
383-
self.env.cr.precommit()
383+
self.env.cr.precommit.run()
384384
count = self.cr.sql_log_count - count0
385385
if count != expected:
386386
# add some info on caller to allow semi-automatic update of query count
@@ -399,11 +399,11 @@ def assertQueryCount(self, default=0, flush=True, **counters):
399399
# same operations, otherwise the caches might not be ready!
400400
if flush:
401401
self.env.user.flush()
402-
self.env.cr.precommit()
402+
self.env.cr.precommit.run()
403403
yield
404404
if flush:
405405
self.env.user.flush()
406-
self.env.cr.precommit()
406+
self.env.cr.precommit.run()
407407

408408
def assertRecordValues(self, records, expected_values):
409409
''' Compare a recordset with a list of dictionaries representing the expected results.

odoo/tools/misc.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,54 +1082,72 @@ def add(self, elem):
10821082
OrderedSet.add(self, elem)
10831083

10841084

1085-
class GroupCalls:
1086-
""" A collection of callbacks with support for aggregated arguments. Upon
1087-
call, every registered function is called once with positional arguments.
1088-
When registering a function, a tuple of positional arguments is returned, so
1089-
that the caller can modify the arguments in place. This allows to
1090-
accumulate some data to process once::
1085+
class Callbacks:
1086+
""" A simple queue of callback functions. Upon run, every function is
1087+
called (in addition order), and the queue is emptied.
10911088
1092-
callbacks = GroupCalls()
1089+
callbacks = Callbacks()
10931090
1094-
# register print (by default with a list)
1095-
[args] = callbacks.register(print, list)
1096-
args.append(42)
1091+
# add foo
1092+
def foo():
1093+
print("foo")
10971094
1098-
# add an element to the list to print
1099-
[args] = callbacks.register(print, list)
1100-
args.append(43)
1095+
callbacks.add(foo)
11011096
1102-
# print "[42, 43]"
1103-
callbacks()
1097+
# add bar
1098+
callbacks.add
1099+
def bar():
1100+
print("bar")
1101+
1102+
# add foo again
1103+
callbacks.add(foo)
1104+
1105+
# call foo(), bar(), foo(), then clear the callback queue
1106+
callbacks.run()
1107+
1108+
The queue also provides a ``data`` dictionary, that may be freely used to
1109+
store anything, but is mostly aimed at aggregating data for callbacks. The
1110+
dictionary is automatically cleared by ``run()`` once all callback functions
1111+
have been called.
1112+
1113+
# register foo to process aggregated data
1114+
@callbacks.add
1115+
def foo():
1116+
print(sum(callbacks.data['foo']))
1117+
1118+
callbacks.data.setdefault('foo', []).append(1)
1119+
...
1120+
callbacks.data.setdefault('foo', []).append(2)
1121+
...
1122+
callbacks.data.setdefault('foo', []).append(3)
1123+
1124+
# call foo(), which prints 6
1125+
callbacks.run()
1126+
1127+
Given the global nature of ``data``, the keys should identify in a unique
1128+
way the data being stored. It is recommended to use strings with a
1129+
structure like ``"{module}.{feature}"``.
11041130
"""
1131+
__slots__ = ['_funcs', 'data']
1132+
11051133
def __init__(self):
1106-
self._func_args = {} # {func: args}
1134+
self._funcs = []
1135+
self.data = {}
11071136

1108-
def __call__(self):
1109-
""" Call all the registered functions (in first addition order) with
1110-
their respective arguments. Only recurrent functions remain registered
1111-
after the call.
1112-
"""
1113-
func_args = self._func_args
1114-
while func_args:
1115-
func = next(iter(func_args))
1116-
args = func_args.pop(func)
1117-
func(*args)
1118-
1119-
def add(self, func, *types):
1120-
""" Register the given function, and return the tuple of positional
1121-
arguments to call the function with. If the function is not registered
1122-
yet, the list of arguments is made up by invoking the given types.
1123-
"""
1124-
try:
1125-
return self._func_args[func]
1126-
except KeyError:
1127-
args = self._func_args[func] = [type_() for type_ in types]
1128-
return args
1137+
def add(self, func):
1138+
""" Add the given function. """
1139+
self._funcs.append(func)
1140+
1141+
def run(self):
1142+
""" Call all the functions (in addition order), then clear. """
1143+
for func in self._funcs:
1144+
func()
1145+
self.clear()
11291146

11301147
def clear(self):
1131-
""" Remove all callbacks from self. """
1132-
self._func_args.clear()
1148+
""" Remove all callbacks and data from self. """
1149+
self._funcs.clear()
1150+
self.data.clear()
11331151

11341152

11351153
class IterableGenerator:

0 commit comments

Comments
 (0)