Skip to content
27 changes: 25 additions & 2 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ async def push_notification_callback() -> None:
)

except Exception:
logger.exception('Agent execution failed')
await self._handle_execution_failure(producer_task, queue)
raise
finally:
if interrupted_or_non_blocking:
Expand Down Expand Up @@ -392,6 +392,10 @@ async def on_message_send_stream(
bg_task.set_name(f'background_consume:{task_id}')
self._track_background_task(bg_task)
raise
except Exception:
# If the consumer fails (e.g. a database error), we must clean up.
await self._handle_execution_failure(producer_task, queue)
raise
finally:
cleanup_task = asyncio.create_task(
self._cleanup_producer(producer_task, task_id)
Expand Down Expand Up @@ -429,13 +433,32 @@ def _on_done(completed: asyncio.Task) -> None:

task.add_done_callback(_on_done)

async def _handle_execution_failure(
    self, producer_task: asyncio.Task, queue: EventQueue
) -> None:
    """Abort a failed execution: log it, stop the producer, close the queue.

    Intended to be awaited from inside an ``except`` block, because the
    failure is reported via ``logger.exception`` and therefore relies on
    an exception being currently handled.

    Args:
        producer_task: The task running the agent (the event producer).
        queue: The event queue the producer writes to.
    """
    logger.exception('Agent execution failed')

    # A failed consumer no longer drains the queue, so the producer could
    # otherwise hang on queue operations (e.g. waiting for it to drain).
    producer_task.cancel()

    # Close right away instead of waiting for pending events; this is what
    # unblocks any producer still waiting on the queue.
    await queue.close(immediate=True)

async def _cleanup_producer(
    self,
    producer_task: asyncio.Task,
    task_id: str,
) -> None:
    """Cleans up the agent execution task and queue manager entry.

    Waits for the producer to finish, then closes the task's queue in the
    queue manager and removes the task from the running-agents registry.
    The producer's outcome is only logged here: a cancelled or failed
    producer must not prevent the bookkeeping below from running.

    Args:
        producer_task: The agent execution task to wait for.
        task_id: Identifier of the task whose resources are released.
    """
    # The producer is awaited exactly once, inside the guard. Awaiting it
    # unguarded first would re-raise its CancelledError/exception and skip
    # the queue close and registry removal, leaking the entry.
    try:
        await producer_task
    except asyncio.CancelledError:
        # Expected when a failure handler cancelled the producer.
        logger.debug(
            'Producer task %s was cancelled during cleanup', task_id
        )
    except Exception:
        # Cleanup boundary: record the failure and continue releasing
        # resources rather than propagating out of a background task.
        logger.exception('Producer task %s failed during cleanup', task_id)

    await self._queue_manager.close(task_id)
    async with self._running_agents_lock:
        self._running_agents.pop(task_id, None)
Expand Down
168 changes: 168 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,3 +2644,171 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
f'Task {task_id} was specified but does not exist'
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
    """Streaming path: a consumer (result aggregator) failure must cancel the producer and close the queue immediately."""
    task_id = 'error_cleanup_task'
    context_id = 'error_cleanup_ctx'

    # Handler collaborators, all replaced by spec'd mocks.
    task_store = AsyncMock(spec=TaskStore)
    queue_manager = AsyncMock(spec=QueueManager)
    agent_executor = AsyncMock(spec=AgentExecutor)
    context_builder = AsyncMock(spec=RequestContextBuilder)

    request_context = MagicMock(spec=RequestContext)
    request_context.task_id = task_id
    request_context.context_id = context_id
    context_builder.build.return_value = request_context

    event_queue = AsyncMock(spec=EventQueue)
    queue_manager.create_or_tap.return_value = event_queue

    request_handler = DefaultRequestHandler(
        agent_executor=agent_executor,
        task_store=task_store,
        queue_manager=queue_manager,
        request_context_builder=context_builder,
    )

    # No task_id on the message: supplying one would trigger the
    # "Task ... was specified but does not exist" error instead.
    params = MessageSendParams(
        message=Message(
            role=Role.user,
            message_id='msg_error_cleanup',
            parts=[],
        )
    )

    async def raise_error_gen(_consumer):
        # Simulates the consumer blowing up on first iteration.
        raise ValueError('Consumer failed!')
        yield  # unreachable; its presence makes this an async generator

    failing_aggregator = MagicMock(spec=ResultAggregator)
    failing_aggregator.consume_and_emit.side_effect = raise_error_gen

    # Intercept producer registration so the task can be inspected later.
    captured_producer_task = None
    original_register = request_handler._register_producer

    async def spy_register_producer(tid, task):
        nonlocal captured_producer_task
        captured_producer_task = task
        # Replace cancel() with a spy that still delegates to the real one.
        task.cancel = MagicMock(wraps=task.cancel)
        await original_register(tid, task)

    with (
        patch(
            'a2a.server.request_handlers.default_request_handler.ResultAggregator',
            return_value=failing_aggregator,
        ),
        patch(
            'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
            return_value=None,
        ),
        patch.object(
            request_handler,
            '_register_producer',
            side_effect=spy_register_producer,
        ),
    ):
        # Act: draining the stream must surface the consumer's error.
        with pytest.raises(ValueError, match='Consumer failed!'):
            async for _ in request_handler.on_message_send_stream(
                params, create_server_call_context()
            ):
                pass

    # Assert: the producer was registered and then cancelled ...
    assert captured_producer_task is not None
    captured_producer_task.cancel.assert_called()

    # ... and the queue was closed without waiting for pending events.
    event_queue.close.assert_awaited_with(immediate=True)


@pytest.mark.asyncio
async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue():
    """Blocking path: a consumer failure during the wait must cancel the producer and close the queue immediately."""
    task_id = 'error_cleanup_blocking_task'
    context_id = 'error_cleanup_blocking_ctx'

    # Handler collaborators, all replaced by spec'd mocks.
    task_store = AsyncMock(spec=TaskStore)
    queue_manager = AsyncMock(spec=QueueManager)
    agent_executor = AsyncMock(spec=AgentExecutor)
    context_builder = AsyncMock(spec=RequestContextBuilder)

    request_context = MagicMock(spec=RequestContext)
    request_context.task_id = task_id
    request_context.context_id = context_id
    context_builder.build.return_value = request_context

    event_queue = AsyncMock(spec=EventQueue)
    queue_manager.create_or_tap.return_value = event_queue

    request_handler = DefaultRequestHandler(
        agent_executor=agent_executor,
        task_store=task_store,
        queue_manager=queue_manager,
        request_context_builder=context_builder,
    )

    params = MessageSendParams(
        message=Message(
            role=Role.user,
            message_id='msg_error_blocking',
            parts=[],
        )
    )

    # The aggregator fails as soon as the handler starts consuming.
    failing_aggregator = MagicMock(spec=ResultAggregator)
    failing_aggregator.consume_and_break_on_interrupt.side_effect = (
        ValueError('Consumer failed!')
    )

    # Intercept producer registration so the task can be inspected later.
    captured_producer_task = None
    original_register = request_handler._register_producer

    async def spy_register_producer(tid, task):
        nonlocal captured_producer_task
        captured_producer_task = task
        # Replace cancel() with a spy that still delegates to the real one.
        task.cancel = MagicMock(wraps=task.cancel)
        await original_register(tid, task)

    with (
        patch(
            'a2a.server.request_handlers.default_request_handler.ResultAggregator',
            return_value=failing_aggregator,
        ),
        patch(
            'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
            return_value=None,
        ),
        patch.object(
            request_handler,
            '_register_producer',
            side_effect=spy_register_producer,
        ),
    ):
        # Act: the blocking call must surface the consumer's error.
        with pytest.raises(ValueError, match='Consumer failed!'):
            await request_handler.on_message_send(
                params, create_server_call_context()
            )

    # Assert: the producer was registered and then cancelled ...
    assert captured_producer_task is not None
    captured_producer_task.cancel.assert_called()

    # ... and the queue was closed without waiting for pending events.
    event_queue.close.assert_awaited_with(immediate=True)
1 change: 0 additions & 1 deletion tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ async def streaming_coro():

self.assertIsInstance(response.root, JSONRPCErrorResponse)
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
Expand Down
Loading