Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions fastapi_utils/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import warnings
from functools import wraps
from traceback import format_exception
from typing import Any, Callable, Coroutine, Union
Expand All @@ -10,7 +11,26 @@

NoArgsNoReturnFuncT = Callable[[], None]
NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]]
NoArgsNoReturnDecorator = Callable[[Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]], NoArgsNoReturnAsyncFuncT]
ExcArgNoReturnFuncT = Callable[[Exception], None]
ExcArgNoReturnAsyncFuncT = Callable[[Exception], Coroutine[Any, Any, None]]
NoArgsNoReturnAnyFuncT = Union[NoArgsNoReturnFuncT, NoArgsNoReturnAsyncFuncT]
ExcArgNoReturnAnyFuncT = Union[ExcArgNoReturnFuncT, ExcArgNoReturnAsyncFuncT]
NoArgsNoReturnDecorator = Callable[[NoArgsNoReturnAnyFuncT], NoArgsNoReturnAsyncFuncT]


async def _handle_func(func: NoArgsNoReturnAnyFuncT) -> None:
if asyncio.iscoroutinefunction(func):
await func()
else:
await run_in_threadpool(func)


async def _handle_exc(exc: Exception, on_exception: ExcArgNoReturnAnyFuncT | None) -> None:
if on_exception:
if asyncio.iscoroutinefunction(on_exception):
await on_exception(exc)
else:
await run_in_threadpool(on_exception, exc)


def repeat_every(
Expand All @@ -20,6 +40,8 @@ def repeat_every(
logger: logging.Logger | None = None,
raise_exceptions: bool = False,
max_repetitions: int | None = None,
on_complete: NoArgsNoReturnAnyFuncT | None = None,
on_exception: ExcArgNoReturnAnyFuncT | None = None,
) -> NoArgsNoReturnDecorator:
"""
This function returns a decorator that modifies a function so it is periodically re-executed after its first call.
Expand All @@ -34,47 +56,62 @@ def repeat_every(
wait_first: float (default None)
If not None, the function will wait for the given duration before the first call
logger: Optional[logging.Logger] (default None)
Warning: This parameter is deprecated and will be removed in the 1.0 release.
The logger to use to log any exceptions raised by calls to the decorated function.
If not provided, exceptions will not be logged by this function (though they may be handled by the event loop).
raise_exceptions: bool (default False)
Warning: This parameter is deprecated and will be removed in the 1.0 release.
If True, errors raised by the decorated function will be raised to the event loop's exception handler.
Note that if an error is raised, the repeated execution will stop.
Otherwise, exceptions are just logged and the execution continues to repeat.
See https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.set_exception_handler for more info.
max_repetitions: Optional[int] (default None)
The maximum number of times to call the repeated function. If `None`, the function is repeated forever.
on_complete: Optional[Callable[[], None]] (default None)
A function to call after the final repetition of the decorated function.
on_exception: Optional[Callable[[Exception], None]] (default None)
A function to call when an exception is raised by the decorated function.
"""

def decorator(func: NoArgsNoReturnAsyncFuncT | NoArgsNoReturnFuncT) -> NoArgsNoReturnAsyncFuncT:
def decorator(func: NoArgsNoReturnAnyFuncT) -> NoArgsNoReturnAsyncFuncT:
"""
Converts the decorated function into a repeated, periodically-called version of itself.
"""
is_coroutine = asyncio.iscoroutinefunction(func)

@wraps(func)
async def wrapped() -> None:
repetitions = 0

async def loop() -> None:
nonlocal repetitions
if wait_first is not None:
await asyncio.sleep(wait_first)

repetitions = 0
while max_repetitions is None or repetitions < max_repetitions:
try:
if is_coroutine:
await func() # type: ignore
else:
await run_in_threadpool(func)
await _handle_func(func)

except Exception as exc:
if logger is not None:
warnings.warn(
"'logger' is to be deprecated in favor of 'on_exception' in the 1.0 release.",
DeprecationWarning,
)
formatted_exception = "".join(format_exception(type(exc), exc, exc.__traceback__))
logger.error(formatted_exception)
if raise_exceptions:
warnings.warn(
"'raise_excpeions' is to be deprecated in favor of 'on_exception' in the 1.0 release.",
DeprecationWarning,
)
raise exc
await _handle_exc(exc, on_exception)

repetitions += 1
await asyncio.sleep(seconds)

await loop()
if on_complete:
await _handle_func(on_complete)

asyncio.ensure_future(loop())

return wrapped

Expand Down
188 changes: 121 additions & 67 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import sys
from typing import TYPE_CHECKING, NoReturn

Expand Down Expand Up @@ -37,42 +38,92 @@ def wait_first(seconds: float) -> float:
class TestRepeatEveryBase:
def setup_method(self) -> None:
self.counter = 0
self.completed = asyncio.Event()

def increase_counter(self) -> None:
self.counter += 1

async def increase_counter_async(self) -> None:
self.increase_counter()

def loop_completed(self) -> None:
self.completed.set()

async def loop_completed_async(self) -> None:
self.loop_completed()

def kill_loop(self, exc: Exception) -> None:
self.completed.set()
raise exc

async def kill_loop_async(self, exc: Exception) -> None:
self.kill_loop(exc)

def continue_loop(self, exc: Exception) -> None:
return

async def continue_loop_async(self, exc: Exception) -> None:
self.continue_loop(exc)

def raise_exc(self) -> NoReturn:
self.increase_counter()
raise ValueError("error")

async def raise_exc_async(self) -> NoReturn:
self.raise_exc()

class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter)
def increase_counter_task(self, is_async: bool, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, on_complete=self.loop_completed)
func = self.increase_counter_async if is_async else self.increase_counter
return decorator(func)

@pytest.fixture
def wait_first_increase_counter_task(
self, seconds: float, max_repetitions: int, wait_first: float
self, is_async: bool, seconds: float, max_repetitions: int, wait_first: float
) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first)
return decorator(self.increase_counter)
decorator = repeat_every(
seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first, on_complete=self.loop_completed
)
func = self.increase_counter_async if is_async else self.increase_counter
return decorator(func)

@staticmethod
@pytest.fixture
def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, max_repetitions=max_repetitions)
def raise_exc() -> NoReturn:
raise ValueError("error")
def stop_on_exception_task(self, is_async: bool, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
on_complete = self.loop_completed_async if is_async else self.loop_completed
on_exception = self.kill_loop_async if is_async else self.kill_loop
decorator = repeat_every(
seconds=seconds,
max_repetitions=max_repetitions,
on_complete=on_complete,
on_exception=on_exception,
)
func = self.raise_exc_async if is_async else self.raise_exc
return decorator(func)

return raise_exc

@staticmethod
@pytest.fixture
def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, raise_exceptions=True)
def raise_exc() -> NoReturn:
raise ValueError("error")
def suppressed_exception_task(
self, is_async: bool, seconds: float, max_repetitions: int
) -> NoArgsNoReturnAsyncFuncT:
on_complete = self.loop_completed_async if is_async else self.loop_completed
on_exception = self.continue_loop_async if is_async else self.continue_loop
decorator = repeat_every(
seconds=seconds,
max_repetitions=max_repetitions,
on_complete=on_complete,
on_exception=on_exception,
)
func = self.raise_exc_async if is_async else self.raise_exc
return decorator(func)

return raise_exc

class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def is_async(self) -> bool:
return False

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions(
self,
Expand All @@ -82,73 +133,62 @@ async def test_max_repetitions(
increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions_and_wait_first(
self,
asyncio_sleep_mock: AsyncMock,
seconds: float,
max_repetitions: int,
wait_first: float,
wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await wait_first_increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)

@pytest.mark.asyncio
async def test_raise_exceptions_false(
self, seconds: float, max_repetitions: int, raising_task: NoArgsNoReturnAsyncFuncT
@pytest.mark.timeout(1)
async def test_stop_loop_on_exc(
self,
stop_on_exception_task: NoArgsNoReturnAsyncFuncT,
) -> None:
try:
await raising_task()
except ValueError as e:
pytest.fail(f"{self.test_raise_exceptions_false.__name__} raised an exception: {e}")
await stop_on_exception_task()
await self.completed.wait()

assert self.counter == 1

@pytest.mark.asyncio
async def test_raise_exceptions_true(
self, seconds: float, suppressed_exception_task: NoArgsNoReturnAsyncFuncT
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_continue_loop_on_exc(
self,
asyncio_sleep_mock: AsyncMock,
seconds: float,
max_repetitions: int,
suppressed_exception_task: NoArgsNoReturnAsyncFuncT,
) -> None:
with pytest.raises(ValueError):
await suppressed_exception_task()

await suppressed_exception_task()
await self.completed.wait()

class TestRepeatEveryWithAsynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter)

@pytest.fixture
def wait_first_increase_counter_task(
self, seconds: float, max_repetitions: int, wait_first: float
) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first)
return decorator(self.increase_counter)

@staticmethod
@pytest.fixture
def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, max_repetitions=max_repetitions)
async def raise_exc() -> NoReturn:
raise ValueError("error")
assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)

return raise_exc

@staticmethod
class TestRepeatEveryWithAsynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, raise_exceptions=True)
async def raise_exc() -> NoReturn:
raise ValueError("error")

return raise_exc
def is_async(self) -> bool:
return True

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions(
self,
Expand All @@ -158,11 +198,13 @@ async def test_max_repetitions(
increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions_and_wait_first(
self,
Expand All @@ -172,22 +214,34 @@ async def test_max_repetitions_and_wait_first(
wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await wait_first_increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)

@pytest.mark.asyncio
async def test_raise_exceptions_false(
self, seconds: float, max_repetitions: int, raising_task: NoArgsNoReturnAsyncFuncT
@pytest.mark.timeout(1)
async def test_stop_loop_on_exc(
self,
stop_on_exception_task: NoArgsNoReturnAsyncFuncT,
) -> None:
try:
await raising_task()
except ValueError as e:
pytest.fail(f"{self.test_raise_exceptions_false.__name__} raised an exception: {e}")
await stop_on_exception_task()
await self.completed.wait()

assert self.counter == 1

@pytest.mark.asyncio
async def test_raise_exceptions_true(
self, seconds: float, suppressed_exception_task: NoArgsNoReturnAsyncFuncT
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_continue_loop_on_exc(
self,
asyncio_sleep_mock: AsyncMock,
seconds: float,
max_repetitions: int,
suppressed_exception_task: NoArgsNoReturnAsyncFuncT,
) -> None:
with pytest.raises(ValueError):
await suppressed_exception_task()
await suppressed_exception_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)