@@ -222,3 +222,105 @@ def once(tx):
222222 for failure in failures :
223223 with self .subTest (failure = failure ):
224224 _test ()
225+
226+ def test_reset_fails_after_pull (self ):
227+ def _test (invalid_response_ , api_ ):
228+ def check_exception (exc ):
229+ self .assertEqual (
230+ exc .exception .code ,
231+ "Neo.TransientError.Statement."
232+ "RemoteExecutionTransientError"
233+ )
234+ if self .driver_supports_features (
235+ types .Feature .API_RETRYABLE_EXCEPTION
236+ ):
237+ self .assertTrue (exc .exception .retryable )
238+
239+ def api_call (session_ ):
240+ if api_ == "session" :
241+ with self .assertRaises (types .DriverError ) as exc :
242+ result = session_ .run ("RETURN 1 AS n" )
243+ list (result )
244+ check_exception (exc )
245+ elif api_ == "explicit_tx" :
246+ tx = session_ .begin_transaction ()
247+ try :
248+ with self .assertRaises (types .DriverError ) as exc :
249+ result = tx .run ("RETURN 1 AS n" )
250+ list (result )
251+ check_exception (exc )
252+ finally :
253+ tx .close ()
254+ elif api_ == "managed_tx" :
255+ run = 0
256+
257+ def work (tx ):
258+ nonlocal run
259+ run += 1
260+ if run == 1 :
261+ with self .assertRaises (types .DriverError ) as exc :
262+ result = tx .run ("RETURN 1 AS n" )
263+ list (result )
264+ check_exception (exc )
265+ raise exc .exception
266+ else :
267+ result = tx .run ("RETURN 1 AS n" )
268+ return list (result )
269+
270+ records = session_ .execute_write (work )
271+ assert len (records ) == 1
272+ self .assertEqual (records , [
273+ types .Record (values = [types .CypherInt (1 )])
274+ ])
275+ else :
276+ raise ValueError (f"Unknown API: { api_ } " )
277+
278+ self ._server .start (
279+ path = self .script_path ("reset_fails_after_pull.script" ),
280+ vars_ = {
281+ "#INVALID_RESPONSE#" : invalid_response_ ,
282+ }
283+ )
284+ auth = types .AuthorizationToken ("basic" , principal = "" ,
285+ credentials = "" )
286+ driver = Driver (self ._backend ,
287+ "bolt://%s" % self ._server .address , auth )
288+ try :
289+ session = driver .session ("r" )
290+ try :
291+ api_call (session )
292+
293+ finally :
294+ session .close ()
295+ # driver should've killed the misbehaving connection
296+ try :
297+ self .assertEqual (
298+ self ._server .count_responses ("<HANGUP>" ), 1
299+ )
300+ finally :
301+ self ._server ._dump ()
302+ finally :
303+ driver .close ()
304+ self ._server .done ()
305+
306+ invalid_responses = (
307+ (
308+ 'S: FAILURE {"code": "Neo.ClientError.General.Unknown", '
309+ '"message": "The driver should ignore this error!"}'
310+ ),
311+ "S: IGNORED" ,
312+ (
313+ "# MIXED \n "
314+ "IF: invalid_responses <= 1\n "
315+ ' S: FAILURE {"code": "Neo.ClientError.General.Unknown", '
316+ '"message": "The driver should ignore this error!"}\n '
317+ "ELSE:\n "
318+ " S: IGNORED\n "
319+ )
320+ )
321+ for invalid_response in invalid_responses :
322+ for api in ("session" , "explicit_tx" , "managed_tx" ):
323+ with self .subTest (response = invalid_response [2 :10 ].strip (),
324+ api = api ):
325+ _test (invalid_response , api )
326+ self ._server .reset ()
0 commit comments