|
| 1 | +import functools |
| 2 | +import logging |
| 3 | +import typing |
| 4 | + |
| 5 | +import asyncpg |
| 6 | +import tenacity |
| 7 | +from sqlalchemy.exc import DBAPIError |
| 8 | + |
| 9 | +from db_try import settings |
| 10 | + |
| 11 | + |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +def _retry_handler(exception: BaseException) -> bool: |
| 16 | + if ( |
| 17 | + isinstance(exception, DBAPIError) |
| 18 | + and hasattr(exception, "orig") |
| 19 | + and isinstance(exception.orig.__cause__, (asyncpg.SerializationError, asyncpg.PostgresConnectionError)) # type: ignore[union-attr] |
| 20 | + ): |
| 21 | + logger.debug("postgres_retry, retrying") |
| 22 | + return True |
| 23 | + |
| 24 | + logger.debug("postgres_retry, giving up on retry") |
| 25 | + return False |
| 26 | + |
| 27 | + |
| 28 | +def postgres_retry[**P, T]( |
| 29 | + func: typing.Callable[P, typing.Coroutine[None, None, T]], |
| 30 | +) -> typing.Callable[P, typing.Coroutine[None, None, T]]: |
| 31 | + @tenacity.retry( |
| 32 | + stop=tenacity.stop_after_attempt(settings.DB_UTILS_RETRIES_NUMBER), |
| 33 | + wait=tenacity.wait_exponential_jitter(), |
| 34 | + retry=tenacity.retry_if_exception(_retry_handler), |
| 35 | + reraise=True, |
| 36 | + before=tenacity.before_log(logger, logging.DEBUG), |
| 37 | + ) |
| 38 | + @functools.wraps(func) |
| 39 | + async def wrapped_method(*args: P.args, **kwargs: P.kwargs) -> T: |
| 40 | + return await func(*args, **kwargs) |
| 41 | + |
| 42 | + return wrapped_method |
0 commit comments