Skip to content

Retry calls to execute as well as to query #577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gel/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def lower(
class ExecuteContext(typing.NamedTuple):
query: QueryWithArgs
cache: QueryCache
retry_options: typing.Optional[options.RetryOptions]
state: typing.Optional[options.State]
warning_handler: options.WarningHandler
annotations: typing.Dict[str, str]
Expand Down Expand Up @@ -187,8 +188,9 @@ class BaseReadOnlyExecutor(abc.ABC):
def _get_query_cache(self) -> QueryCache:
...

@abc.abstractmethod
def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
return None
...

@abc.abstractmethod
def _get_state(self) -> options.State:
Expand Down Expand Up @@ -303,6 +305,7 @@ def execute(self, commands: str, *args, **kwargs) -> None:
self._execute(ExecuteContext(
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand All @@ -317,6 +320,7 @@ def execute_sql(self, commands: str, *args, **kwargs) -> None:
input_language=protocol.InputLanguage.SQL,
),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand Down Expand Up @@ -438,6 +442,7 @@ async def execute(self, commands: str, *args, **kwargs) -> None:
await self._execute(ExecuteContext(
query=QueryWithArgs(commands, args, kwargs),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand All @@ -452,6 +457,7 @@ async def execute_sql(self, commands: str, *args, **kwargs) -> None:
input_language=protocol.InputLanguage.SQL,
),
cache=self._get_query_cache(),
retry_options=self._get_retry_options(),
state=self._get_state(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
Expand Down
58 changes: 37 additions & 21 deletions gel/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,32 +197,18 @@ def is_in_transaction(self) -> bool:
def get_settings(self) -> typing.Dict[str, typing.Any]:
return self._protocol.get_settings()

async def raw_query(self, query_context: abstract.QueryContext):
if self.is_closed():
await self.connect()

async def _retry_operation(self, func, retry_options, ctx):
reconnect = False
i = 0
if self._protocol.is_legacy:
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
allow_capabilities = enums.Capability.EXECUTE
ctx = query_context.lower(allow_capabilities=allow_capabilities)
while True:
i += 1
try:
if reconnect:
await self.connect(single_attempt=True)
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(ctx)
else:
res = await self._protocol.query(ctx)
if ctx.warnings:
res = query_context.warning_handler(ctx.warnings, res)
return res
return await func()

except errors.EdgeDBError as e:
if query_context.retry_options is None:
if retry_options is None:
raise
if not e.has_tag(errors.SHOULD_RETRY):
raise e
Expand All @@ -234,12 +220,37 @@ async def raw_query(self, query_context: abstract.QueryContext):
and not isinstance(e, errors.TransactionConflictError)
):
raise e
rule = query_context.retry_options.get_rule_for_exception(e)
rule = retry_options.get_rule_for_exception(e)
if i >= rule.attempts:
raise e
await self.sleep(rule.backoff(i))
reconnect = self.is_closed()

async def raw_query(self, query_context: abstract.QueryContext):
if self.is_closed():
await self.connect()

reconnect = False
i = 0
if self._protocol.is_legacy:
allow_capabilities = enums.Capability.LEGACY_EXECUTE
else:
allow_capabilities = enums.Capability.EXECUTE
ctx = query_context.lower(allow_capabilities=allow_capabilities)

async def _inner():
if self._protocol.is_legacy:
return await self._protocol.legacy_execute_anonymous(ctx)
else:
res = await self._protocol.query(ctx)
if ctx.warnings:
res = query_context.warning_handler(ctx.warnings, res)
return res

return await self._retry_operation(
_inner, query_context.retry_options, ctx
)

async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
if self._protocol.is_legacy:
if execute_context.query.args or execute_context.query.kwargs:
Expand All @@ -253,9 +264,14 @@ async def _execute(self, execute_context: abstract.ExecuteContext) -> None:
ctx = execute_context.lower(
allow_capabilities=enums.Capability.EXECUTE
)
res = await self._protocol.execute(ctx)
if ctx.warnings:
res = execute_context.warning_handler(ctx.warnings, res)
async def _inner():
res = await self._protocol.execute(ctx)
if ctx.warnings:
res = execute_context.warning_handler(ctx.warnings, res)

return await self._retry_operation(
_inner, execute_context.retry_options, ctx
)

async def describe(
self, describe_context: abstract.DescribeContext
Expand Down
4 changes: 4 additions & 0 deletions gel/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ async def _exit(self, extype, ex):
def _get_query_cache(self) -> abstract.QueryCache:
return self._client._get_query_cache()

def _get_retry_options(self) -> typing.Optional[options.RetryOptions]:
return None

def _get_state(self) -> options.State:
return self._client._get_state()

Expand All @@ -206,6 +209,7 @@ async def _privileged_execute(self, query: str) -> None:
query=abstract.QueryWithArgs(query, (), {}),
cache=self._get_query_cache(),
state=self._get_state(),
retry_options=self._get_retry_options(),
warning_handler=self._get_warning_handler(),
annotations=self._get_annotations(),
))
Expand Down
71 changes: 67 additions & 4 deletions tests/test_async_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ class TestAsyncRetry(tb.AsyncQueryTestCase):
};
'''

TEARDOWN = '''
DROP TYPE test::Counter;
'''

async def test_async_retry_01(self):
async for tx in self.client.transaction():
async with tx:
Expand Down Expand Up @@ -206,6 +202,73 @@ async def transaction1(client):
self.assertEqual(set(results), {1, 2})
self.assertEqual(iterations, 3)

async def test_async_retry_conflict_nontx_01(self):
await self.execute_nontx_conflict(
'counter_nontx_01',
lambda client, *args, **kwargs: client.query(*args, **kwargs)
)

async def test_async_retry_conflict_nontx_02(self):
await self.execute_nontx_conflict(
'counter_nontx_02',
lambda client, *args, **kwargs: client.execute(*args, **kwargs)
)

async def execute_nontx_conflict(self, name, func):
# Test retries on conflicts in a non-tx setting. We do this
# by having conflicting upserts that are made long-running by
# adding a sys::_sleep call.
#
# Unlike for the tx ones, we don't assert that a retry
# actually was necessary, since that feels fragile in a
# timing-based test like this.

client1 = self.client
client2 = self.make_test_client(database=self.get_database_name())
self.addCleanup(client2.aclose)

await client1.query("SELECT 1")
await client2.query("SELECT 1")

query = '''
SELECT (
INSERT test::Counter {
name := <str>$name,
value := 1,
} UNLESS CONFLICT ON .name
ELSE (
UPDATE test::Counter
SET { value := .value + 1 }
)
).value
ORDER BY sys::_sleep(<int64>$sleep)
THEN <int64>$nonce
'''

await func(client1, query, name=name, sleep=0, nonce=0)

task1 = asyncio.create_task(
func(client1, query, name=name, sleep=5, nonce=1)
)
task2 = asyncio.create_task(
func(client2, query, name=name, sleep=5, nonce=2)
)

results = await asyncio.wait_for(asyncio.gather(
task1,
task2,
return_exceptions=True,
), 20)

excs = [e for e in results if isinstance(e, BaseException)]
if excs:
raise excs[0]
val = await client1.query_single('''
select (select test::Counter filter .name = <str>$name).value
''', name=name)

self.assertEqual(val, 3)

async def test_async_transaction_interface_errors(self):
with self.assertRaisesRegex(
AttributeError,
Expand Down
Loading