Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions fastapi_utils/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def repeat_every(
logger: logging.Logger | None = None,
raise_exceptions: bool = False,
max_repetitions: int | None = None,
on_complete: NoArgsNoReturnFuncT | None = None
) -> NoArgsNoReturnDecorator:
"""
This function returns a decorator that modifies a function so it is periodically re-executed after its first call.
Expand Down Expand Up @@ -53,28 +54,35 @@ def decorator(func: NoArgsNoReturnAsyncFuncT | NoArgsNoReturnFuncT) -> NoArgsNoR

@wraps(func)
async def wrapped() -> None:
repetitions = 0

async def loop() -> None:
nonlocal repetitions
if wait_first is not None:
await asyncio.sleep(wait_first)

repetitions = 0
while max_repetitions is None or repetitions < max_repetitions:
try:
if is_coroutine:
await func() # type: ignore
else:
await run_in_threadpool(func)

except Exception as exc:
if logger is not None:
formatted_exception = "".join(format_exception(type(exc), exc, exc.__traceback__))
logger.error(formatted_exception)
if raise_exceptions:
raise exc

repetitions += 1
await asyncio.sleep(seconds)

await loop()
if on_complete:
if asyncio.iscoroutinefunction(on_complete):
await on_complete()
else:
await run_in_threadpool(on_complete)

asyncio.ensure_future(loop())

return wrapped

Expand Down
70 changes: 25 additions & 45 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import asyncio
from typing import TYPE_CHECKING, NoReturn

if TYPE_CHECKING:
Expand Down Expand Up @@ -37,42 +38,43 @@ def wait_first(seconds: float) -> float:
class TestRepeatEveryBase:
def setup_method(self) -> None:
self.counter = 0
self.completed = asyncio.Event()

def increase_counter(self) -> None:
self.counter += 1

def loop_completed(self) -> None:
self.completed.set()

def raise_exc(self) -> NoReturn:
raise ValueError("error")

class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter)
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, on_complete=self.loop_completed)
return decorator(self.increase_counter)

@pytest.fixture
def wait_first_increase_counter_task(
self, seconds: float, max_repetitions: int, wait_first: float
) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first)
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first, on_complete=self.loop_completed)
return decorator(self.increase_counter)

@staticmethod
@pytest.fixture
def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, max_repetitions=max_repetitions)
def raise_exc() -> NoReturn:
raise ValueError("error")
def raising_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, on_complete=self.loop_completed)
return decorator(self.raise_exc)

return raise_exc

@staticmethod
@pytest.fixture
def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, raise_exceptions=True)
def raise_exc() -> NoReturn:
raise ValueError("error")
def suppressed_exception_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, raise_exceptions=True, on_complete=self.loop_completed)
return decorator(self.raise_exc)

return raise_exc

class TestRepeatEveryWithSynchronousFunction(TestRepeatEveryBase):
@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions(
self,
Expand All @@ -82,11 +84,13 @@ async def test_max_repetitions(
increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions_and_wait_first(
self,
Expand All @@ -97,6 +101,7 @@ async def test_max_repetitions_and_wait_first(
wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await wait_first_increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)
Expand All @@ -119,36 +124,8 @@ async def test_raise_exceptions_true(


class TestRepeatEveryWithAsynchronousFunction(TestRepeatEveryBase):
@pytest.fixture
def increase_counter_task(self, seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
return repeat_every(seconds=seconds, max_repetitions=max_repetitions)(self.increase_counter)

@pytest.fixture
def wait_first_increase_counter_task(
self, seconds: float, max_repetitions: int, wait_first: float
) -> NoArgsNoReturnAsyncFuncT:
decorator = repeat_every(seconds=seconds, max_repetitions=max_repetitions, wait_first=wait_first)
return decorator(self.increase_counter)

@staticmethod
@pytest.fixture
def raising_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, max_repetitions=max_repetitions)
async def raise_exc() -> NoReturn:
raise ValueError("error")

return raise_exc

@staticmethod
@pytest.fixture
def suppressed_exception_task(seconds: float, max_repetitions: int) -> NoArgsNoReturnAsyncFuncT:
@repeat_every(seconds=seconds, raise_exceptions=True)
async def raise_exc() -> NoReturn:
raise ValueError("error")

return raise_exc

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions(
self,
Expand All @@ -158,11 +135,13 @@ async def test_max_repetitions(
increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls(max_repetitions * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
@patch("asyncio.sleep")
async def test_max_repetitions_and_wait_first(
self,
Expand All @@ -172,6 +151,7 @@ async def test_max_repetitions_and_wait_first(
wait_first_increase_counter_task: NoArgsNoReturnAsyncFuncT,
) -> None:
await wait_first_increase_counter_task()
await self.completed.wait()

assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)
Expand Down