Skip to content

Commit e66f3aa

Browse files
authored
Updated InMemoryBroker (#84)
1 parent 8cac3a6 commit e66f3aa

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import asyncio
12
import inspect
23
from collections import OrderedDict
3-
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, get_type_hints
4+
from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints
45

56
from taskiq_dependencies import DependencyGraph
67

@@ -114,6 +115,7 @@ def __init__( # noqa: WPS211
114115
log_collector_format=logs_format or WorkerArgs.log_collector_format,
115116
),
116117
)
118+
self._running_tasks: "Set[asyncio.Task[Any]]" = set()
117119

118120
async def kick(self, message: BrokerMessage) -> None:
119121
"""
@@ -128,6 +130,7 @@ async def kick(self, message: BrokerMessage) -> None:
128130
target_task = self.available_tasks.get(message.task_name)
129131
if target_task is None:
130132
raise TaskiqError("Unknown task.")
133+
131134
if not self.receiver.dependency_graphs.get(target_task.task_name):
132135
self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph(
133136
target_task.original_func,
@@ -141,7 +144,9 @@ async def kick(self, message: BrokerMessage) -> None:
141144
target_task.original_func,
142145
)
143146

144-
await self.receiver.callback(message=message)
147+
task = asyncio.create_task(self.receiver.callback(message=message))
148+
self._running_tasks.add(task)
149+
task.add_done_callback(self._running_tasks.discard)
145150

146151
def listen(self) -> AsyncGenerator[BrokerMessage, None]:
147152
"""

tests/brokers/test_inmemory.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import asyncio
2+
import uuid
3+
4+
import pytest
5+
6+
from taskiq import InMemoryBroker
7+
from taskiq.events import TaskiqEvents
8+
from taskiq.state import TaskiqState
9+
10+
11+
@pytest.mark.anyio
12+
async def test_inmemory_success() -> None:
13+
broker = InMemoryBroker()
14+
test_val = uuid.uuid4().hex
15+
16+
@broker.task
17+
async def task() -> str:
18+
return test_val
19+
20+
kicked = await task.kiq()
21+
result = await kicked.wait_result()
22+
assert result.return_value == test_val
23+
assert not broker._running_tasks
24+
25+
26+
@pytest.mark.anyio
27+
async def test_cannot_listen() -> None:
28+
broker = InMemoryBroker()
29+
30+
with pytest.raises(RuntimeError):
31+
async for _ in broker.listen():
32+
pass
33+
34+
35+
@pytest.mark.anyio
36+
async def test_startup() -> None:
37+
broker = InMemoryBroker()
38+
test_value = uuid.uuid4().hex
39+
40+
@broker.on_event(TaskiqEvents.WORKER_STARTUP)
41+
async def _w_startup(state: TaskiqState) -> None:
42+
state.from_worker = test_value
43+
44+
@broker.on_event(TaskiqEvents.CLIENT_STARTUP)
45+
async def _c_startup(state: TaskiqState) -> None:
46+
state.from_client = test_value
47+
48+
await broker.startup()
49+
50+
assert broker.state.from_worker == test_value
51+
assert broker.state.from_client == test_value
52+
53+
54+
@pytest.mark.anyio
55+
async def test_shutdown() -> None:
56+
broker = InMemoryBroker()
57+
test_value = uuid.uuid4().hex
58+
59+
@broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)
60+
async def _w_startup(state: TaskiqState) -> None:
61+
state.from_worker = test_value
62+
63+
@broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN)
64+
async def _c_startup(state: TaskiqState) -> None:
65+
state.from_client = test_value
66+
67+
await broker.shutdown()
68+
69+
assert broker.state.from_worker == test_value
70+
assert broker.state.from_client == test_value
71+
72+
73+
@pytest.mark.anyio
74+
async def test_execution() -> None:
75+
broker = InMemoryBroker()
76+
test_value = uuid.uuid4().hex
77+
78+
@broker.task
79+
async def test_task() -> str:
80+
await asyncio.sleep(0.5)
81+
return test_value
82+
83+
task = await test_task.kiq()
84+
assert not await task.is_ready()
85+
86+
result = await task.wait_result()
87+
assert result.return_value == test_value

0 commit comments

Comments
 (0)