|
8 | 8 | from datetime import datetime, timezone
|
9 | 9 | from functools import partial
|
10 | 10 |
|
| 11 | +from discord.errors import Forbidden |
| 12 | + |
11 | 13 | from pydis_core.utils import logging
|
| 14 | +from pydis_core.utils.error_handling import handle_forbidden_from_block |
12 | 15 |
|
13 | 16 | _background_tasks: set[asyncio.Task] = set()
|
14 | 17 |
|
@@ -77,7 +80,7 @@ def schedule(self, task_id: abc.Hashable, coroutine: abc.Coroutine) -> None:
|
77 | 80 | coroutine.close()
|
78 | 81 | return
|
79 | 82 |
|
80 |
| - task = asyncio.create_task(coroutine, name=f"{self.name}_{task_id}") |
| 83 | + task = asyncio.create_task(_coro_wrapper(coroutine), name=f"{self.name}_{task_id}") |
81 | 84 | task.add_done_callback(partial(self._task_done_callback, task_id))
|
82 | 85 |
|
83 | 86 | self._scheduled_tasks[task_id] = task
|
@@ -238,21 +241,29 @@ def create_task(
|
238 | 241 | asyncio.Task: The wrapped task.
|
239 | 242 | """
|
240 | 243 | if event_loop is not None:
|
241 |
| - task = event_loop.create_task(coro, **kwargs) |
| 244 | + task = event_loop.create_task(_coro_wrapper(coro), **kwargs) |
242 | 245 | else:
|
243 |
| - task = asyncio.create_task(coro, **kwargs) |
| 246 | + task = asyncio.create_task(_coro_wrapper(coro), **kwargs) |
244 | 247 |
|
245 | 248 | _background_tasks.add(task)
|
246 | 249 | task.add_done_callback(_background_tasks.discard)
|
247 | 250 | task.add_done_callback(partial(_log_task_exception, suppressed_exceptions=suppressed_exceptions))
|
248 | 251 | return task
|
249 | 252 |
|
250 | 253 |
|
| 254 | +async def _coro_wrapper(coro: abc.Coroutine[typing.Any, typing.Any, TASK_RETURN]) -> None: |
| 255 | + """Wraps `coro` in a try/except block that will handle 90001 Forbidden errors.""" |
| 256 | + try: |
| 257 | + await coro |
| 258 | + except Forbidden as e: |
| 259 | + await handle_forbidden_from_block(e) |
| 260 | + |
| 261 | + |
251 | 262 | def _log_task_exception(task: asyncio.Task, *, suppressed_exceptions: tuple[type[Exception], ...]) -> None:
|
252 |
| - """Retrieve and log the exception raised in ``task`` if one exists.""" |
| 263 | + """Retrieve and log the exception raised in ``task``, if one exists and it's not suppressed.""" |
253 | 264 | with contextlib.suppress(asyncio.CancelledError):
|
254 | 265 | exception = task.exception()
|
255 |
| - # Log the exception if one exists. |
| 266 | + # Log the exception if one exists and it's not suppressed/handled. |
256 | 267 | if exception and not isinstance(exception, suppressed_exceptions):
|
257 | 268 | log = logging.get_logger(__name__)
|
258 | 269 | log.error(f"Error in task {task.get_name()} {id(task)}!", exc_info=exception)
|
0 commit comments