Skip to content

Commit 10a1d15

Browse files
committed
Fix async handler detection for partial callables
1 parent 95feb52 commit 10a1d15

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

src/core/interfaces/command_service.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3+
from asyncio import iscoroutinefunction as asyncio_iscoroutinefunction
34
from collections.abc import Awaitable, Callable
4-
from inspect import iscoroutinefunction
55
from typing import Any
66

77
from src.core.domain.processed_result import ProcessedResult
@@ -13,11 +13,22 @@
1313
def _is_async_callable(candidate: Any) -> bool:
1414
"""Return ``True`` when *candidate* is an awaitable callable."""
1515

16-
if iscoroutinefunction(candidate): # Fast path for async callables
16+
if asyncio_iscoroutinefunction(candidate): # Handles partials and decorated callables
17+
return True
18+
19+
func_attr = getattr(candidate, "func", None)
20+
if func_attr and asyncio_iscoroutinefunction(func_attr):
1721
return True
1822

1923
call_method = getattr(candidate, "__call__", None)
20-
return bool(call_method and iscoroutinefunction(call_method))
24+
if not call_method:
25+
return False
26+
27+
if asyncio_iscoroutinefunction(call_method):
28+
return True
29+
30+
bound_function = getattr(call_method, "__func__", None)
31+
return bool(bound_function and asyncio_iscoroutinefunction(bound_function))
2132

2233

2334
class FunctionCommandService(ICommandService):

tests/unit/core/test_command_service_module.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import pytest
24
from src.core.domain.processed_result import ProcessedResult
35
from src.core.interfaces.command_service import ensure_command_service
@@ -48,6 +50,28 @@ async def handler(messages: list[str], session_id: str) -> ProcessedResult:
4850
assert result.command_results == ["session"]
4951

5052

53+
@pytest.mark.asyncio
54+
async def test_ensure_command_service_accepts_partial_async_callable() -> None:
55+
async def handler(
56+
messages: list[str], session_id: str, prefix: str
57+
) -> ProcessedResult:
58+
return ProcessedResult(
59+
modified_messages=[f"{prefix}:{value}" for value in messages],
60+
command_executed=bool(messages),
61+
command_results=[session_id],
62+
)
63+
64+
partial_handler = partial(handler, prefix="partial")
65+
66+
validated_service = ensure_command_service(partial_handler)
67+
68+
result = await validated_service.process_commands(["message"], "session")
69+
70+
assert result.modified_messages == ["partial:message"]
71+
assert result.command_executed is True
72+
assert result.command_results == ["session"]
73+
74+
5175
def test_ensure_command_service_rejects_sync_callable() -> None:
5276
def handler(messages: list[str], session_id: str) -> ProcessedResult:
5377
return ProcessedResult(

0 commit comments

Comments
 (0)