1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from synapse .storage .database import make_tuple_comparison_clause
15+ from typing import Callable , NoReturn , Tuple
16+ from unittest .mock import Mock
17+
18+ from twisted .test .proto_helpers import MemoryReactor
19+
20+ from synapse .server import HomeServer
21+ from synapse .storage .database import (
22+ DatabasePool ,
23+ LoggingTransaction ,
24+ make_tuple_comparison_clause ,
25+ )
1626from synapse .storage .engines import BaseDatabaseEngine
27+ from synapse .util import Clock
1728
1829from tests import unittest
1930
@@ -38,3 +49,103 @@ def test_native_tuple_comparison(self):
3849 clause , args = make_tuple_comparison_clause ([("a" , 1 ), ("b" , 2 )])
3950 self .assertEqual (clause , "(a,b) > (?,?)" )
4051 self .assertEqual (args , [1 , 2 ])
52+
53+
54+ class CallbacksTestCase (unittest .HomeserverTestCase ):
55+ """Tests for transaction callbacks."""
56+
57+ def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) -> None :
58+ self .store = hs .get_datastores ().main
59+ self .db_pool : DatabasePool = self .store .db_pool
60+
61+ def _run_interaction (
62+ self , func : Callable [[LoggingTransaction , int ], None ]
63+ ) -> Tuple [Mock , Mock ]:
64+ """Run the given function in a database transaction, with callbacks registered.
65+
66+ Args:
67+ func: The function to be run in a transaction. The transaction will be
68+ retried if `func` raises an `OperationalError`.
69+
70+ Returns:
71+ Two mocks, which were registered as an `after_callback` and an
72+ `exception_callback` respectively, on every transaction attempt.
73+ """
74+ after_callback = Mock ()
75+ exception_callback = Mock ()
76+
77+ def _test_txn (txn : LoggingTransaction ) -> None :
78+ txn .call_after (after_callback , 123 , 456 , extra = 789 )
79+ txn .call_on_exception (exception_callback , 987 , 654 , extra = 321 )
80+ func (txn )
81+
82+ try :
83+ self .get_success_or_raise (
84+ self .db_pool .runInteraction ("test_transaction" , _test_txn )
85+ )
86+ except Exception :
87+ pass
88+
89+ return after_callback , exception_callback
90+
91+ def test_after_callback (self ) -> None :
92+ """Test that the after callback is called when a transaction succeeds."""
93+ after_callback , exception_callback = self ._run_interaction (lambda txn : None )
94+
95+ after_callback .assert_called_once_with (123 , 456 , extra = 789 )
96+ exception_callback .assert_not_called ()
97+
98+ def test_exception_callback (self ) -> None :
99+ """Test that the exception callback is called when a transaction fails."""
100+ after_callback , exception_callback = self ._run_interaction (lambda txn : 1 / 0 )
101+
102+ after_callback .assert_not_called ()
103+ exception_callback .assert_called_once_with (987 , 654 , extra = 321 )
104+
105+ def test_failed_retry (self ) -> None :
106+ """Test that the exception callback is called for every failed attempt."""
107+
108+ def _test_txn (txn : LoggingTransaction ) -> NoReturn :
109+ """Simulate a retryable failure on every attempt."""
110+ raise self .db_pool .engine .module .OperationalError ()
111+
112+ after_callback , exception_callback = self ._run_interaction (_test_txn )
113+
114+ after_callback .assert_not_called ()
115+ exception_callback .assert_has_calls (
116+ [
117+ ((987 , 654 ), {"extra" : 321 }),
118+ ((987 , 654 ), {"extra" : 321 }),
119+ ((987 , 654 ), {"extra" : 321 }),
120+ ((987 , 654 ), {"extra" : 321 }),
121+ ((987 , 654 ), {"extra" : 321 }),
122+ ((987 , 654 ), {"extra" : 321 }),
123+ ]
124+ )
125+ self .assertEqual (exception_callback .call_count , 6 ) # no additional calls
126+
127+ def test_successful_retry (self ) -> None :
128+ """Test callbacks for a failed transaction followed by a successful attempt."""
129+ first_attempt = True
130+
131+ def _test_txn (txn : LoggingTransaction ) -> None :
132+ """Simulate a retryable failure on the first attempt only."""
133+ nonlocal first_attempt
134+ if first_attempt :
135+ first_attempt = False
136+ raise self .db_pool .engine .module .OperationalError ()
137+ else :
138+ return None
139+
140+ after_callback , exception_callback = self ._run_interaction (_test_txn )
141+
142+ # Calling both `after_callback`s when the first attempt failed is rather
143+ # dubious. But let's document the behaviour in a test.
144+ after_callback .assert_has_calls (
145+ [
146+ ((123 , 456 ), {"extra" : 789 }),
147+ ((123 , 456 ), {"extra" : 789 }),
148+ ]
149+ )
150+ self .assertEqual (after_callback .call_count , 2 ) # no additional calls
151+ exception_callback .assert_not_called ()
0 commit comments