Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,15 @@ async def _preprocess_async(

# If it's a toolset, process it first
if isinstance(tool_union, BaseToolset):
# Generate preprocessing events (e.g., authentication requests)
async with Aclosing(
tool_union.generate_preprocessing_events(
tool_context=tool_context, llm_request=llm_request
)
) as agen:
async for event in agen:
yield event

await tool_union.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)
Expand Down
30 changes: 30 additions & 0 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC
from abc import abstractmethod
import copy
from typing import AsyncGenerator
from typing import final
from typing import List
from typing import Optional
Expand All @@ -31,6 +32,7 @@
from .base_tool import BaseTool

if TYPE_CHECKING:
from ..events.event import Event
from ..models.llm_request import LlmRequest
from .tool_configs import ToolArgsConfig
from .tool_context import ToolContext
Expand Down Expand Up @@ -204,3 +206,31 @@ async def process_llm_request(
llm_request: The outgoing LLM request, mutable this method.
"""
pass

async def generate_preprocessing_events(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
"""Generates events during the preprocessing phase.

This method allows toolsets to generate events (such as authentication
requests) before tool discovery occurs. It has access to the full
ToolContext with authentication capabilities.

Use cases:
- OAuth2 authentication flows before tool discovery
- User confirmation requests for sensitive toolsets
- Dynamic configuration based on user context
- Pre-flight checks that require user interaction

Args:
tool_context: The context of the tool with full authentication capabilities.
llm_request: The outgoing LLM request, mutable by this method.

Yields:
Event: Events for user interaction (e.g., authentication requests).
"""
# Default implementation yields nothing (backward compatibility)
# Subclasses can override to yield authentication or other events
if False: # This ensures the method is an AsyncGenerator
yield # Required for AsyncGenerator type hint
return
Comment on lines +234 to +236

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if False: yield pattern to create an empty async generator is a bit of a workaround. A more modern and explicit approach is to use yield from (). This is more readable, idiomatic, and achieves the same result of creating an empty async generator.

    yield from ()

152 changes: 152 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Unit tests for BaseLlmFlow toolset integration."""

from typing import AsyncGenerator
from unittest import mock
from unittest.mock import AsyncMock

Expand All @@ -26,6 +27,7 @@
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.google_search_tool import GoogleSearchTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
import pytest

Expand Down Expand Up @@ -91,6 +93,156 @@ async def close(self):
assert mock_toolset.process_llm_request_called


@pytest.mark.asyncio
async def test_preprocess_calls_toolset_generate_preprocessing_events():
"""Test that _preprocess_async calls generate_preprocessing_events on toolsets."""

# Create a mock toolset that tracks if generate_preprocessing_events was called
class _MockToolset(BaseToolset):

def __init__(self):
super().__init__()
self.generate_preprocessing_events_called = False
self.generated_events = []

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The self.generated_events list is initialized here, but it's never used or asserted against in the test. It appears to be redundant and can be removed for clarity. Please also remove the line that appends to it (line 121).


async def generate_preprocessing_events(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
self.generate_preprocessing_events_called = True
# Generate a mock authentication event
auth_event = Event(
author='system',
invocation_id='test_invocation',
content=types.Content(
role='model',
parts=[types.Part(text='Mock authentication request')],
),
)
self.generated_events.append(auth_event)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This line appends to the self.generated_events list, which is unused. It can be removed along with the list's initialization on line 106.

yield auth_event

async def get_tools(self, readonly_context=None):
return []

async def close(self):
pass

mock_toolset = _MockToolset()

# Create a mock model that returns a simple response
mock_response = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Test response')]
),
partial=False,
)

mock_model = testing_utils.MockModel.create(responses=[mock_response])

# Create agent with the mock toolset
agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset])
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)

flow = BaseLlmFlowForTesting()

# Call _preprocess_async
llm_request = LlmRequest()
events = []
async for event in flow._preprocess_async(invocation_context, llm_request):
events.append(event)

# Verify that generate_preprocessing_events was called on the toolset
assert mock_toolset.generate_preprocessing_events_called

# Verify that the generated event was yielded
assert len(events) == 1
assert events[0].author == 'system'
assert events[0].content.parts[0].text == 'Mock authentication request'


@pytest.mark.asyncio
async def test_preprocess_calls_both_generate_events_and_process_request():
"""Test that _preprocess_async calls both generate_preprocessing_events and process_llm_request."""

# Create a mock toolset that tracks both method calls
class _MockToolset(BaseToolset):

def __init__(self):
super().__init__()
self.generate_preprocessing_events_called = False
self.process_llm_request_called = False
self.call_order = []

async def generate_preprocessing_events(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
self.generate_preprocessing_events_called = True
self.call_order.append('generate_preprocessing_events')
# Generate a mock event
yield Event(
author='system',
invocation_id='test_invocation',
content=types.Content(
role='model', parts=[types.Part(text='Mock event')]
),
)

async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
self.process_llm_request_called = True
self.call_order.append('process_llm_request')

async def get_tools(self, readonly_context=None):
return []

async def close(self):
pass

mock_toolset = _MockToolset()

# Create a mock model that returns a simple response
mock_response = LlmResponse(
content=types.Content(
role='model', parts=[types.Part.from_text(text='Test response')]
),
partial=False,
)

mock_model = testing_utils.MockModel.create(responses=[mock_response])

# Create agent with the mock toolset
agent = Agent(name='test_agent', model=mock_model, tools=[mock_toolset])
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)

flow = BaseLlmFlowForTesting()

# Call _preprocess_async
llm_request = LlmRequest()
events = []
async for event in flow._preprocess_async(invocation_context, llm_request):
events.append(event)

# Verify that both methods were called
assert mock_toolset.generate_preprocessing_events_called
assert mock_toolset.process_llm_request_called

# Verify the correct call order (generate_preprocessing_events first)
assert mock_toolset.call_order == [
'generate_preprocessing_events',
'process_llm_request',
]

# Verify that the generated event was yielded
assert len(events) == 1
assert events[0].author == 'system'
assert events[0].content.parts[0].text == 'Mock event'


@pytest.mark.asyncio
async def test_preprocess_handles_mixed_tools_and_toolsets():
"""Test that _preprocess_async properly handles both tools and toolsets."""
Expand Down