Skip to content
Open
125 changes: 110 additions & 15 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import asyncio
from datetime import datetime
from datetime import timezone
import inspect
Expand Down Expand Up @@ -87,9 +88,16 @@ def __init__(
super().__init__()
self._runner = runner
self._config = config or A2aAgentExecutorConfig()
# Track active tasks by task_id for cancellation support
self._active_tasks: dict[str, asyncio.Task] = {}
# Lock to protect _active_tasks from race conditions
self._tasks_lock = asyncio.Lock()

async def _resolve_runner(self) -> Runner:
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
"""Resolve the runner.

Handles cases where it's a callable returning a Runner.
"""
# If already resolved and cached, return it
if isinstance(self._runner, Runner):
return self._runner
Expand All @@ -114,9 +122,70 @@ async def _resolve_runner(self) -> Runner:

@override
async def cancel(self, context: RequestContext, event_queue: EventQueue):
"""Cancel the execution."""
# TODO: Implement proper cancellation logic if needed
raise NotImplementedError('Cancellation is not supported')
"""Cancel the execution of a running task.

Args:
context: The request context containing the task_id to cancel.
event_queue: The event queue to publish cancellation events to.

If the task is found and running, it will be cancelled and a cancellation
event will be published. If the task is not found or already completed,
the method will log a warning and return gracefully.
"""
if not context.task_id:
logger.warning('Cannot cancel task: no task_id provided in context')
return

# Use lock to prevent race conditions with _handle_request cleanup
async with self._tasks_lock:
task = self._active_tasks.pop(context.task_id, None)

if not task:
logger.warning(
'Task %s not found or already completed', context.task_id
)
return

if task.done():
logger.info('Task %s already completed', context.task_id)
return

# Cancel the task (outside lock to avoid blocking other operations)
logger.info('Cancelling task %s', context.task_id)
if not task.cancel():
# Task completed before it could be cancelled
logger.info(
'Task %s completed before it could be cancelled', context.task_id
)
return

try:
# Wait for cancellation to complete with timeout
await asyncio.wait_for(task, timeout=1.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
# Expected when task is cancelled or timeout occurs
pass

# Publish cancellation event
try:
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=context.task_id,
status=TaskStatus(
state=TaskState.failed,
timestamp=datetime.now(timezone.utc).isoformat(),
message=Message(
message_id=str(uuid.uuid4()),
role=Role.agent,
parts=[TextPart(text='Task was cancelled')],
),
),
context_id=context.context_id,
final=True,
)
)
except Exception as e:
logger.error('Failed to publish cancellation event: %s', e, exc_info=True)

@override
async def execute(
Expand Down Expand Up @@ -221,17 +290,43 @@ async def _handle_request(
)

task_result_aggregator = TaskResultAggregator()
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
async for adk_event in agen:
for a2a_event in self._config.event_converter(
adk_event,
invocation_context,
context.task_id,
context.context_id,
self._config.gen_ai_part_converter,
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)

# Helper function to iterate over async generator
async def _process_events():
async with Aclosing(runner.run_async(**vars(run_request))) as agen:
async for adk_event in agen:
for a2a_event in self._config.event_converter(
adk_event,
invocation_context,
context.task_id,
context.context_id,
self._config.gen_ai_part_converter,
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)

# Create and track the task for cancellation support
if context.task_id:
task = asyncio.create_task(_process_events())
# Use lock to prevent race conditions with cancel()
async with self._tasks_lock:
self._active_tasks[context.task_id] = task
try:
await task
except asyncio.CancelledError:
# Task was cancelled
# Note: cancellation event is published by cancel() method,
# so we just log and handle gracefully here
logger.info('Task %s was cancelled', context.task_id)
# Return early - don't publish completion events for cancelled tasks
return
finally:
# Clean up task tracking (use lock to prevent race conditions)
async with self._tasks_lock:
self._active_tasks.pop(context.task_id, None)
else:
# No task_id, run without tracking
await _process_events()

# publish the task result event - this is final
if (
Expand Down
176 changes: 166 additions & 10 deletions tests/unittests/a2a/executor/test_a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
Expand All @@ -20,6 +21,7 @@
from a2a.server.events.event_queue import EventQueue
from a2a.types import Message
from a2a.types import TaskState
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from google.adk.a2a.converters.request_converter import AgentRunRequest
from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor
Expand Down Expand Up @@ -583,22 +585,176 @@ async def test_cancel_with_task_id(self):
"""Test cancellation with a task ID."""
self.mock_context.task_id = "test-task-id"

# The current implementation raises NotImplementedError
with pytest.raises(
NotImplementedError, match="Cancellation is not supported"
):
await self.executor.cancel(self.mock_context, self.mock_event_queue)
# Cancel should succeed without raising
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# If no task is running, should log warning but not raise
# Verify event queue was not called (no task to cancel)
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_without_task_id(self):
"""Test cancellation without a task ID."""
self.mock_context.task_id = None

# The current implementation raises NotImplementedError regardless of task_id
with pytest.raises(
NotImplementedError, match="Cancellation is not supported"
):
await self.executor.cancel(self.mock_context, self.mock_event_queue)
# Cancel should handle missing task_id gracefully
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish any events when task_id is missing
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_running_task(self):
"""Test cancellation of a running task."""
self.mock_context.task_id = "test-task-id"

# Setup: Create a running task by starting execution
self.mock_request_converter.return_value = AgentRunRequest(
user_id="test-user",
session_id="test-session",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
mock_invocation_context = Mock()
self.mock_runner._new_invocation_context.return_value = (
mock_invocation_context
)

# Create an async generator that yields events slowly
async def slow_generator():
mock_event = Mock(spec=Event)
yield mock_event
# This will hang if not cancelled
await asyncio.sleep(10)

# Replace run_async with the async generator function
self.mock_runner.run_async = slow_generator
self.mock_event_converter.return_value = []

# Start execution in background
execute_task = asyncio.create_task(
self.executor.execute(self.mock_context, self.mock_event_queue)
)

# Wait a bit to ensure task is running
await asyncio.sleep(0.1)

# Cancel the task
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Wait for cancellation to complete
try:
await asyncio.wait_for(execute_task, timeout=2.0)
except asyncio.CancelledError:
pass

# Verify cancellation event was published
assert self.mock_event_queue.enqueue_event.call_count > 0
# Find the cancellation event (should be the last one with failed state)
cancellation_events = [
call[0][0]
for call in self.mock_event_queue.enqueue_event.call_args_list
if isinstance(call[0][0], TaskStatusUpdateEvent)
and call[0][0].status.state == TaskState.failed
and call[0][0].final is True
]
assert len(cancellation_events) > 0, "No cancellation event found"
cancellation_event = cancellation_events[-1]
assert cancellation_event.status.state == TaskState.failed
assert cancellation_event.final is True

@pytest.mark.asyncio
async def test_cancel_nonexistent_task(self):
"""Test cancellation of a non-existent task."""
self.mock_context.task_id = "nonexistent-task-id"

# Cancel should handle gracefully
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish any events for non-existent task
assert self.mock_event_queue.enqueue_event.call_count == 0

@pytest.mark.asyncio
async def test_cancel_completed_task(self):
"""Test cancellation of an already completed task."""
self.mock_context.task_id = "test-task-id"

# Setup and run a task to completion
self.mock_request_converter.return_value = AgentRunRequest(
user_id="test-user",
session_id="test-session",
new_message=Mock(spec=Content),
run_config=Mock(spec=RunConfig),
)
mock_session = Mock()
mock_session.id = "test-session"
self.mock_runner.session_service.get_session = AsyncMock(
return_value=mock_session
)
mock_invocation_context = Mock()
self.mock_runner._new_invocation_context.return_value = (
mock_invocation_context
)

# Create a generator that completes immediately
async def quick_generator():
mock_event = Mock(spec=Event)
yield mock_event

self.mock_runner.run_async.return_value = quick_generator()
self.mock_event_converter.return_value = []

# Run to completion
await self.executor.execute(self.mock_context, self.mock_event_queue)

# Now try to cancel (should handle gracefully)
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Should not publish additional cancellation event for completed task
# (The execute already published final event)

@pytest.mark.asyncio
async def test_cancel_race_condition_task_completes_before_cancel(self):
"""Test race condition where task completes before cancel() is called."""
self.mock_context.task_id = "test-task-id"

# Create a mock task that is already done
mock_task = Mock(spec=asyncio.Task)
mock_task.done.return_value = False # Initially not done (passes check)
mock_task.cancel.return_value = (
False # Returns False because task completed between check and cancel
)

# Manually add task to _active_tasks to simulate race condition
self.executor._active_tasks["test-task-id"] = mock_task

# Call cancel
await self.executor.cancel(self.mock_context, self.mock_event_queue)

# Verify task.cancel() was called
mock_task.cancel.assert_called_once()

# Verify no cancellation event was published (since cancel() returned False)
# Check that no TaskStatusUpdateEvent with "Task was cancelled" was published
cancellation_events = [
call[0][0]
for call in self.mock_event_queue.enqueue_event.call_args_list
if isinstance(call[0][0], TaskStatusUpdateEvent)
and call[0][0].status.state == TaskState.failed
and any(
part.text == "Task was cancelled"
for part in call[0][0].status.message.parts
if hasattr(part, "text")
)
]
assert (
len(cancellation_events) == 0
), "Should not publish cancellation event when task completed before cancel"

@pytest.mark.asyncio
async def test_execute_with_exception_handling(self):
Expand Down