Skip to content

Commit e263f2b

Browse files
authored
Retry calls to execute as well as to query (#577)
I think `execute` originally didn't support retries because we only retried read-only queries. Now that we also retry transaction errors, retries should be supported on both `query` and `execute`. (I modeled the implementation on how I did this in the edgedb test suite's hacked-up client: geldata/gel#8249.)
1 parent d0eec6c commit e263f2b

File tree

4 files changed: +115 -26 lines changed

4 files changed

+115
-26
lines changed

gel/abstract.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def lower(
9292
class ExecuteContext(typing.NamedTuple):
9393
query: QueryWithArgs
9494
cache: QueryCache
95+
retry_options: typing.Optional[options.RetryOptions]
9596
state: typing.Optional[options.State]
9697
warning_handler: options.WarningHandler
9798
annotations: typing.Dict[str, str]
@@ -187,8 +188,9 @@ class BaseReadOnlyExecutor(abc.ABC):
187188
def _get_query_cache(self) -> QueryCache:
188189
...
189190

191+
@abc.abstractmethod
190192
def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
191-
return None
193+
...
192194

193195
@abc.abstractmethod
194196
def _get_state(self) -> options.State:
@@ -303,6 +305,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
303305
self._execute(ExecuteContext(
304306
query=QueryWithArgs(commands, args, kwargs),
305307
cache=self._get_query_cache(),
308+
retry_options=self._get_retry_options(),
306309
state=self._get_state(),
307310
warning_handler=self._get_warning_handler(),
308311
annotations=self._get_annotations(),
@@ -317,6 +320,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None:
317320
input_language=protocol.InputLanguage.SQL,
318321
),
319322
cache=self._get_query_cache(),
323+
retry_options=self._get_retry_options(),
320324
state=self._get_state(),
321325
warning_handler=self._get_warning_handler(),
322326
annotations=self._get_annotations(),
@@ -438,6 +442,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
438442
await self._execute(ExecuteContext(
439443
query=QueryWithArgs(commands, args, kwargs),
440444
cache=self._get_query_cache(),
445+
retry_options=self._get_retry_options(),
441446
state=self._get_state(),
442447
warning_handler=self._get_warning_handler(),
443448
annotations=self._get_annotations(),
@@ -452,6 +457,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
452457
input_language=protocol.InputLanguage.SQL,
453458
),
454459
cache=self._get_query_cache(),
460+
retry_options=self._get_retry_options(),
455461
state=self._get_state(),
456462
warning_handler=self._get_warning_handler(),
457463
annotations=self._get_annotations(),

gel/base_client.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,18 @@ def is_in_transaction(self) -> bool:
197197
def get_settings(self) -> typing.Dict[str, typing.Any]:
198198
return self._protocol.get_settings()
199199

200-
async def raw_query(self, query_context: abstract.QueryContext):
201-
if self.is_closed():
202-
await self.connect()
203-
200+
async def _retry_operation(self, func, retry_options, ctx):
204201
reconnect = False
205202
i = 0
206-
if self._protocol.is_legacy:
207-
allow_capabilities = enums.Capability.LEGACY_EXECUTE
208-
else:
209-
allow_capabilities = enums.Capability.EXECUTE
210-
ctx = query_context.lower(allow_capabilities=allow_capabilities)
211203
while True:
212204
i += 1
213205
try:
214206
if reconnect:
215207
await self.connect(single_attempt=True)
216-
if self._protocol.is_legacy:
217-
return await self._protocol.legacy_execute_anonymous(ctx)
218-
else:
219-
res = await self._protocol.query(ctx)
220-
if ctx.warnings:
221-
res = query_context.warning_handler(ctx.warnings, res)
222-
return res
208+
return await func()
223209

224210
except errors.EdgeDBError as e:
225-
if query_context.retry_options is None:
211+
if retry_options is None:
226212
raise
227213
if not e.has_tag(errors.SHOULD_RETRY):
228214
raise e
@@ -234,12 +220,37 @@ async def raw_query(self, query_context: abstract.QueryContext):
234220
and not isinstance(e, errors.TransactionConflictError)
235221
):
236222
raise e
237-
rule = query_context.retry_options.get_rule_for_exception(e)
223+
rule = retry_options.get_rule_for_exception(e)
238224
if i >= rule.attempts:
239225
raise e
240226
await self.sleep(rule.backoff(i))
241227
reconnect = self.is_closed()
242228

229+
async def raw_query(self, query_context: abstract.QueryContext):
230+
if self.is_closed():
231+
await self.connect()
232+
233+
reconnect = False
234+
i = 0
235+
if self._protocol.is_legacy:
236+
allow_capabilities = enums.Capability.LEGACY_EXECUTE
237+
else:
238+
allow_capabilities = enums.Capability.EXECUTE
239+
ctx = query_context.lower(allow_capabilities=allow_capabilities)
240+
241+
async def _inner():
242+
if self._protocol.is_legacy:
243+
return await self._protocol.legacy_execute_anonymous(ctx)
244+
else:
245+
res = await self._protocol.query(ctx)
246+
if ctx.warnings:
247+
res = query_context.warning_handler(ctx.warnings, res)
248+
return res
249+
250+
return await self._retry_operation(
251+
_inner, query_context.retry_options, ctx
252+
)
253+
243254
async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
244255
if self._protocol.is_legacy:
245256
if execute_context.query.args or execute_context.query.kwargs:
@@ -253,9 +264,14 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
253264
ctx = execute_context.lower(
254265
allow_capabilities=enums.Capability.EXECUTE
255266
)
256-
res = await self._protocol.execute(ctx)
257-
if ctx.warnings:
258-
res = execute_context.warning_handler(ctx.warnings, res)
267+
async def _inner():
268+
res = await self._protocol.execute(ctx)
269+
if ctx.warnings:
270+
res = execute_context.warning_handler(ctx.warnings, res)
271+
272+
return await self._retry_operation(
273+
_inner, execute_context.retry_options, ctx
274+
)
259275

260276
async def describe(
261277
self, describe_context: abstract.DescribeContext

gel/transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ async def _exit(self, extype, ex):
184184
def _get_query_cache(self) -> abstract.QueryCache:
185185
return self._client._get_query_cache()
186186

187+
def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
188+
return None
189+
187190
def _get_state(self) -> options.State:
188191
return self._client._get_state()
189192

@@ -206,6 +209,7 @@ async def _privileged_execute(self, query: str) -> None:
206209
query=abstract.QueryWithArgs(query, (), {}),
207210
cache=self._get_query_cache(),
208211
state=self._get_state(),
212+
retry_options=self._get_retry_options(),
209213
warning_handler=self._get_warning_handler(),
210214
annotations=self._get_annotations(),
211215
))

tests/test_async_retry.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ class TestAsyncRetry(tb.AsyncQueryTestCase):
5959
};
6060
'''
6161

62-
TEARDOWN = '''
63-
DROP TYPE test::Counter;
64-
'''
65-
6662
async def test_async_retry_01(self):
6763
async for tx in self.client.transaction():
6864
async with tx:
@@ -206,6 +202,73 @@ async def transaction1(client):
206202
self.assertEqual(set(results), {1, 2})
207203
self.assertEqual(iterations, 3)
208204

205+
async def test_async_retry_conflict_nontx_01(self):
206+
await self.execute_nontx_conflict(
207+
'counter_nontx_01',
208+
lambda client, *args, **kwargs: client.query(*args, **kwargs)
209+
)
210+
211+
async def test_async_retry_conflict_nontx_02(self):
212+
await self.execute_nontx_conflict(
213+
'counter_nontx_02',
214+
lambda client, *args, **kwargs: client.execute(*args, **kwargs)
215+
)
216+
217+
async def execute_nontx_conflict(self, name, func):
218+
# Test retries on conflicts in a non-tx setting. We do this
219+
# by having conflicting upserts that are made long-running by
220+
# adding a sys::_sleep call.
221+
#
222+
# Unlike for the tx ones, we don't assert that a retry
223+
# actually was necessary, since that feels fragile in a
224+
# timing-based test like this.
225+
226+
client1 = self.client
227+
client2 = self.make_test_client(database=self.get_database_name())
228+
self.addCleanup(client2.aclose)
229+
230+
await client1.query("SELECT 1")
231+
await client2.query("SELECT 1")
232+
233+
query = '''
234+
SELECT (
235+
INSERT test::Counter {
236+
name := <str>$name,
237+
value := 1,
238+
} UNLESS CONFLICT ON .name
239+
ELSE (
240+
UPDATE test::Counter
241+
SET { value := .value + 1 }
242+
)
243+
).value
244+
ORDER BY sys::_sleep(<int64>$sleep)
245+
THEN <int64>$nonce
246+
'''
247+
248+
await func(client1, query, name=name, sleep=0, nonce=0)
249+
250+
task1 = asyncio.create_task(
251+
func(client1, query, name=name, sleep=5, nonce=1)
252+
)
253+
task2 = asyncio.create_task(
254+
func(client2, query, name=name, sleep=5, nonce=2)
255+
)
256+
257+
results = await asyncio.wait_for(asyncio.gather(
258+
task1,
259+
task2,
260+
return_exceptions=True,
261+
), 20)
262+
263+
excs = [e for e in results if isinstance(e, BaseException)]
264+
if excs:
265+
raise excs[0]
266+
val = await client1.query_single('''
267+
select (select test::Counter filter .name = <str>$name).value
268+
''', name=name)
269+
270+
self.assertEqual(val, 3)
271+
209272
async def test_async_transaction_interface_errors(self):
210273
with self.assertRaisesRegex(
211274
AttributeError,

0 commit comments

Comments
 (0)