Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/google/adk/auth/auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
from ..models.llm_request import LlmRequest
from ..utils.context_utils import Aclosing
from .auth_handler import AuthHandler
from .auth_tool import AuthConfig
from .auth_tool import AuthToolArguments
Expand Down Expand Up @@ -112,7 +113,7 @@ async def run_async(
function_call.id in tools_to_resume
for function_call in function_calls
]):
if function_response_event := await functions.handle_function_calls_async(
if function_response_event_async_gen := functions.handle_function_calls_async_gen(
invocation_context,
event,
{
Expand All @@ -125,7 +126,9 @@ async def run_async(
# auth response would be a dict keyed by function call id
tools_to_resume,
):
yield function_response_event
async with Aclosing(function_response_event_async_gen) as agen:
async for function_response_event in agen:
yield function_response_event
return
return

Expand Down
74 changes: 41 additions & 33 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,43 +683,51 @@ async def _postprocess_handle_function_calls_async(
function_call_event: Event,
llm_request: LlmRequest,
) -> AsyncGenerator[Event, None]:
if function_response_event := await functions.handle_function_calls_async(
if function_response_event_async_gen := functions.handle_function_calls_async_gen(
invocation_context, function_call_event, llm_request.tools_dict
):
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
)
if auth_event:
yield auth_event

tool_confirmation_event = functions.generate_request_confirmation_event(
invocation_context, function_call_event, function_response_event
)
if tool_confirmation_event:
yield tool_confirmation_event

# Always yield the function response event first
yield function_response_event
async with Aclosing(function_response_event_async_gen) as agen:
async for function_response_event in agen:
auth_event = functions.generate_auth_event(
invocation_context, function_response_event
)
if auth_event:
yield auth_event

# Check if this is a set_model_response function response
if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
# Create and yield a final model response event
final_event = (
_output_schema_processor.create_final_model_response_event(
invocation_context, json_response
tool_confirmation_event = (
functions.generate_request_confirmation_event(
invocation_context,
function_call_event,
function_response_event,
)
)
if tool_confirmation_event:
yield tool_confirmation_event

# Always yield the function response event first
yield function_response_event

# Check if this is a set_model_response function response
if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
# Create and yield a final model response event
final_event = (
_output_schema_processor.create_final_model_response_event(
invocation_context, json_response
)
)
)
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async with Aclosing(agent_to_run.run_async(invocation_context)) as agen:
async for event in agen:
yield event
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
invocation_context, transfer_to_agent
)
async with Aclosing(
agent_to_run.run_async(invocation_context)
) as agen:
async for event in agen:
yield event

def _get_agent_to_run(
self, invocation_context: InvocationContext, agent_name: str
Expand Down
178 changes: 132 additions & 46 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import copy
import inspect
import logging
import threading
from typing import Any
from typing import AsyncGenerator
from typing import AsyncIterator
from typing import cast
from typing import Iterator
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
import uuid
Expand All @@ -32,6 +34,7 @@

from ...agents.active_streaming_tool import ActiveStreamingTool
from ...agents.invocation_context import InvocationContext
from ...agents.run_config import StreamingMode
from ...auth.auth_tool import AuthToolArguments
from ...events.event import Event
from ...events.event_actions import EventActions
Expand Down Expand Up @@ -186,16 +189,16 @@ def generate_request_confirmation_event(
)


async def handle_function_calls_async(
def handle_function_calls_async_gen(
invocation_context: InvocationContext,
function_call_event: Event,
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
) -> AsyncGenerator[Optional[Event]]:
"""Calls the functions and returns the function response event."""
function_calls = function_call_event.get_function_calls()
return await handle_function_call_list_async(
return _handle_function_call_list_async_gen(
invocation_context,
function_calls,
tools_dict,
Expand All @@ -204,13 +207,13 @@ async def handle_function_calls_async(
)


async def handle_function_call_list_async(
async def _handle_function_call_list_async_gen(
invocation_context: InvocationContext,
function_calls: list[types.FunctionCall],
tools_dict: dict[str, BaseTool],
filters: Optional[set[str]] = None,
tool_confirmation_dict: Optional[dict[str, ToolConfirmation]] = None,
) -> Optional[Event]:
) -> AsyncGenerator[Optional[Event]]:
"""Calls the functions and returns the function response event."""
from ...agents.llm_agent import LlmAgent

Expand All @@ -222,38 +225,42 @@ async def handle_function_call_list_async(
]

if not filtered_calls:
return None
yield None
return

# Create tasks for parallel execution
tasks = [
asyncio.create_task(
_execute_single_function_call_async(
invocation_context,
function_call,
tools_dict,
agent,
tool_confirmation_dict[function_call.id]
if tool_confirmation_dict
else None,
)
function_call_async_gens = [
_execute_single_function_call_async_gen(
invocation_context,
function_call,
tools_dict,
agent,
tool_confirmation_dict[function_call.id]
if tool_confirmation_dict
else None,
)
for function_call in filtered_calls
]

# Wait for all tasks to complete
function_response_events = await asyncio.gather(*tasks)

# Filter out None results
function_response_events = [
event for event in function_response_events if event is not None
]

if not function_response_events:
return None

merged_event = merge_parallel_function_response_events(
function_response_events
)
merged_event = None
result_events: List[Optional[Event]] = [None] * len(function_call_async_gens)
function_response_events = []
async with Aclosing(
_concat_function_call_generators(function_call_async_gens)
) as agen:
async for idx, event in agen:
result_events[idx] = event
function_response_events = [
event for event in result_events if event is not None
]
if function_response_events:
merged_event = merge_parallel_function_response_events(
function_response_events
)
if invocation_context.run_config.streaming_mode == StreamingMode.SSE:
yield merged_event
if invocation_context.run_config.streaming_mode != StreamingMode.SSE:
yield merged_event

if len(function_response_events) > 1:
# this is needed for debug traces of parallel calls
Expand All @@ -264,16 +271,62 @@ async def handle_function_call_list_async(
response_event_id=merged_event.id,
function_response_event=merged_event,
)
return merged_event


async def _execute_single_function_call_async(
async def _concat_function_call_generators(
gens: List[AsyncGenerator[Any]],
) -> AsyncGenerator[tuple[int, Any]]:
_SENTINEL = object()
q = asyncio.Queue()
gens = list(gens)
n = len(gens)

async def __pump(idx: int, agen_: AsyncGenerator[Any]):
try:
async with Aclosing(agen_) as agen_wrapped:
async for x in agen_wrapped:
await q.put(('ITEM', idx, x))
except Exception as e:
await q.put(('EXC', idx, e))
finally:
aclose = getattr(agen_, 'aclose', None)
if callable(aclose):
try:
await aclose()
except Exception: # noqa: ignore exception when task canceled.
pass

await q.put(('END', idx, _SENTINEL))

tasks = [asyncio.create_task(__pump(i, agen)) for i, agen in enumerate(gens)]
finished = 0
try:
while finished < n:
kind, i, payload = await q.get()
if kind == 'ITEM':
yield i, payload

elif kind == 'EXC':
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise payload

elif kind == 'END':
finished += 1
finally:
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)


async def _execute_single_function_call_async_gen(
invocation_context: InvocationContext,
function_call: types.FunctionCall,
tools_dict: dict[str, BaseTool],
agent: LlmAgent,
tool_confirmation: Optional[ToolConfirmation] = None,
) -> Optional[Event]:
) -> AsyncGenerator[Optional[Event]]:
"""Execute a single function call with thread safety for state modifications."""

async def _run_on_tool_error_callbacks(
Expand Down Expand Up @@ -331,13 +384,14 @@ async def _run_on_tool_error_callbacks(
error=tool_error,
)
if error_response is not None:
return __build_response_event(
yield __build_response_event(
tool, error_response, tool_context, invocation_context
)
return
else:
raise tool_error

async def _run_with_trace():
async def _run_with_trace() -> AsyncGenerator[Optional[Event]]:
nonlocal function_args

# Step 1: Check if plugin before_tool_callback overrides the function
Expand Down Expand Up @@ -366,6 +420,36 @@ async def _run_with_trace():
function_response = await __call_tool_async(
tool, args=function_args, tool_context=tool_context
)
if inspect.isasyncgen(function_response) or isinstance(
function_response, AsyncIterator
):
res = None
async for res in function_response:
if inspect.isawaitable(res):
res = await res
if (
invocation_context.run_config.streaming_mode
== StreamingMode.SSE
):
yield __build_response_event(
tool, res, tool_context, invocation_context
)
function_response = res
elif inspect.isgenerator(function_response) or isinstance(
function_response, Iterator
):
res = None
for res in function_response:
if inspect.isawaitable(res):
res = await res
if (
invocation_context.run_config.streaming_mode
== StreamingMode.SSE
):
yield __build_response_event(
tool, res, tool_context, invocation_context
)
function_response = res
Comment on lines +423 to +452
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for handling asynchronous and synchronous generators is duplicated in these blocks. While separate loops are necessary for async for and for, the body of the loops is identical. Extracting the loop body into a helper function could reduce this duplication and improve maintainability.

except Exception as tool_error:
error_response = await _run_on_tool_error_callbacks(
tool=tool,
Expand Down Expand Up @@ -413,7 +497,8 @@ async def _run_with_trace():
# Allow long running function to return None to not provide function
# response.
if not function_response:
return None
yield None
return

# Note: State deltas are not applied here - they are collected in
# tool_context.actions.state_delta and applied later when the session
Expand All @@ -423,17 +508,18 @@ async def _run_with_trace():
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
return function_response_event
yield function_response_event

with tracer.start_as_current_span(f'execute_tool {tool.name}'):
try:
function_response_event = await _run_with_trace()
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
return function_response_event
async with Aclosing(_run_with_trace()) as agen:
async for function_response_event in agen:
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
yield function_response_event
except:
trace_tool_call(
tool=tool, args=function_args, function_response_event=None
Expand Down
Loading