Skip to content

Updated InMemoryBroker #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import inspect
from collections import OrderedDict
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, get_type_hints
from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints

from taskiq_dependencies import DependencyGraph

Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__( # noqa: WPS211
log_collector_format=logs_format or WorkerArgs.log_collector_format,
),
)
self._running_tasks: "Set[asyncio.Task[Any]]" = set()

async def kick(self, message: BrokerMessage) -> None:
"""
Expand All @@ -128,6 +130,7 @@ async def kick(self, message: BrokerMessage) -> None:
target_task = self.available_tasks.get(message.task_name)
if target_task is None:
raise TaskiqError("Unknown task.")

if not self.receiver.dependency_graphs.get(target_task.task_name):
self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph(
target_task.original_func,
Expand All @@ -141,7 +144,9 @@ async def kick(self, message: BrokerMessage) -> None:
target_task.original_func,
)

await self.receiver.callback(message=message)
task = asyncio.create_task(self.receiver.callback(message=message))
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)

def listen(self) -> AsyncGenerator[BrokerMessage, None]:
"""
Expand Down
87 changes: 87 additions & 0 deletions tests/brokers/test_inmemory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import asyncio
import uuid

import pytest

from taskiq import InMemoryBroker
from taskiq.events import TaskiqEvents
from taskiq.state import TaskiqState


@pytest.mark.anyio
async def test_inmemory_success() -> None:
broker = InMemoryBroker()
test_val = uuid.uuid4().hex

@broker.task
async def task() -> str:
return test_val

kicked = await task.kiq()
result = await kicked.wait_result()
assert result.return_value == test_val
assert not broker._running_tasks


@pytest.mark.anyio
async def test_cannot_listen() -> None:
broker = InMemoryBroker()

with pytest.raises(RuntimeError):
async for _ in broker.listen():
pass


@pytest.mark.anyio
async def test_startup() -> None:
broker = InMemoryBroker()
test_value = uuid.uuid4().hex

@broker.on_event(TaskiqEvents.WORKER_STARTUP)
async def _w_startup(state: TaskiqState) -> None:
state.from_worker = test_value

@broker.on_event(TaskiqEvents.CLIENT_STARTUP)
async def _c_startup(state: TaskiqState) -> None:
state.from_client = test_value

await broker.startup()

assert broker.state.from_worker == test_value
assert broker.state.from_client == test_value


@pytest.mark.anyio
async def test_shutdown() -> None:
broker = InMemoryBroker()
test_value = uuid.uuid4().hex

@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
async def _w_startup(state: TaskiqState) -> None:
state.from_worker = test_value

@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
async def _c_startup(state: TaskiqState) -> None:
state.from_client = test_value

await broker.shutdown()

assert broker.state.from_worker == test_value
assert broker.state.from_client == test_value


@pytest.mark.anyio
async def test_execution() -> None:
broker = InMemoryBroker()
test_value = uuid.uuid4().hex

@broker.task
async def test_task() -> str:
await asyncio.sleep(0.5)
return test_value

task = await test_task.kiq()
assert not await task.is_ready()

result = await task.wait_result()
assert result.return_value == test_value