Skip to content

Commit 2a55584

Browse files
committed
refactor: tool call accept event response
1 parent 530ae04 commit 2a55584

File tree

6 files changed

+354
-347
lines changed

6 files changed

+354
-347
lines changed

src/google/adk/events/event.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from datetime import datetime
1717
import random
1818
import string
19-
from typing import Optional, Dict
19+
from typing import Dict
20+
from typing import Optional
2021

2122
from google.genai import types
2223
from pydantic import alias_generators

src/google/adk/flows/llm_flows/functions.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838
from ...tools.base_tool import BaseTool
3939
from ...tools.tool_context import ToolContext
4040

41-
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
42-
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
41+
AF_FUNCTION_CALL_ID_PREFIX = "adk-"
42+
REQUEST_EUC_FUNCTION_CALL_NAME = "adk_request_credential"
4343

44-
logger = logging.getLogger('google_adk.' + __name__)
44+
logger = logging.getLogger("google_adk." + __name__)
4545

4646

4747
def generate_client_function_call_id() -> str:
48-
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
48+
return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}"
4949

5050

5151
def populate_client_function_call_id(model_response_event: Event) -> None:
@@ -100,7 +100,6 @@ def generate_auth_event(
100100
function_call_id,
101101
auth_config,
102102
) in function_response_event.actions.requested_auth_configs.items():
103-
104103
request_euc_function_call = types.FunctionCall(
105104
name=REQUEST_EUC_FUNCTION_CALL_NAME,
106105
args=AuthToolArguments(
@@ -149,7 +148,7 @@ async def handle_function_calls_async(
149148
tools_dict,
150149
)
151150

152-
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
151+
with tracer.start_as_current_span(f"execute_tool {tool.name}"):
153152
# do not use "args" as the variable name, because it is a reserved keyword
154153
# in python debugger.
155154
function_args = function_call.args or {}
@@ -208,7 +207,7 @@ async def handle_function_calls_async(
208207
# this is needed for debug traces of parallel calls
209208
# individual response with tool.name is traced in __build_response_event
210209
# (we drop tool.name from span name here as this is merged event)
211-
with tracer.start_as_current_span('execute_tool (merged)'):
210+
with tracer.start_as_current_span("execute_tool (merged)"):
212211
trace_merged_tool_calls(
213212
response_event_id=merged_event.id,
214213
function_response_event=merged_event,
@@ -232,7 +231,7 @@ async def handle_function_calls_live(
232231
tool, tool_context = _get_tool_and_context(
233232
invocation_context, function_call_event, function_call, tools_dict
234233
)
235-
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
234+
with tracer.start_as_current_span(f"execute_tool {tool.name}"):
236235
# do not use "args" as the variable name, because it is a reserved keyword
237236
# in python debugger.
238237
function_args = function_call.args or {}
@@ -289,7 +288,7 @@ async def handle_function_calls_live(
289288
tool=tool,
290289
args=function_args,
291290
response_event_id=function_response_event.id,
292-
function_response=function_response,
291+
function_response_event=function_response_event,
293292
)
294293
function_response_events.append(function_response_event)
295294

@@ -302,7 +301,7 @@ async def handle_function_calls_live(
302301
# this is needed for debug traces of parallel calls
303302
# individual response with tool.name is traced in __build_response_event
304303
# (we drop tool.name from span name here as this is merged event)
305-
with tracer.start_as_current_span('execute_tool (merged)'):
304+
with tracer.start_as_current_span("execute_tool (merged)"):
306305
trace_merged_tool_calls(
307306
response_event_id=merged_event.id,
308307
function_response_event=merged_event,
@@ -316,10 +315,10 @@ async def _process_function_live_helper(
316315
function_response = None
317316
# Check if this is a stop_streaming function call
318317
if (
319-
function_call.name == 'stop_streaming'
320-
and 'function_name' in function_args
318+
function_call.name == "stop_streaming"
319+
and "function_name" in function_args
321320
):
322-
function_name = function_args['function_name']
321+
function_name = function_args["function_name"]
323322
active_tasks = invocation_context.active_streaming_tools
324323
if (
325324
function_name in active_tasks
@@ -334,29 +333,29 @@ async def _process_function_live_helper(
334333
except (asyncio.CancelledError, asyncio.TimeoutError):
335334
# Log the specific condition
336335
if task.cancelled():
337-
logging.info(f'Task {function_name} was cancelled successfully')
336+
logging.info(f"Task {function_name} was cancelled successfully")
338337
elif task.done():
339-
logging.info(f'Task {function_name} completed during cancellation')
338+
logging.info(f"Task {function_name} completed during cancellation")
340339
else:
341340
logging.warning(
342-
f'Task {function_name} might still be running after'
343-
' cancellation timeout'
341+
f"Task {function_name} might still be running after"
342+
" cancellation timeout"
344343
)
345344
function_response = {
346-
'status': f'The task is not cancelled yet for {function_name}.'
345+
"status": f"The task is not cancelled yet for {function_name}."
347346
}
348347
if not function_response:
349348
# Clean up the reference
350349
active_tasks[function_name].task = None
351350

352351
function_response = {
353-
'status': f'Successfully stopped streaming function {function_name}'
352+
"status": f"Successfully stopped streaming function {function_name}"
354353
}
355354
else:
356355
function_response = {
357-
'status': f'No active streaming function named {function_name} found'
356+
"status": f"No active streaming function named {function_name} found"
358357
}
359-
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
358+
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
360359
# for streaming tool use case
361360
# we require the function to be a async generator function
362361
async def run_tool_and_update_queue(tool, function_args, tool_context):
@@ -368,10 +367,10 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
368367
invocation_context=invocation_context,
369368
):
370369
updated_content = types.Content(
371-
role='user',
370+
role="user",
372371
parts=[
373372
types.Part.from_text(
374-
text=f'Function {tool.name} returned: {result}'
373+
text=f"Function {tool.name} returned: {result}"
375374
)
376375
],
377376
)
@@ -393,9 +392,9 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
393392
# Immediately return a pending response.
394393
# This is required by current live model.
395394
function_response = {
396-
'status': (
397-
'The function is running asynchronously and the results are'
398-
' pending.'
395+
"status": (
396+
"The function is running asynchronously and the results are"
397+
" pending."
399398
)
400399
}
401400
else:
@@ -413,7 +412,7 @@ def _get_tool_and_context(
413412
):
414413
if function_call.name not in tools_dict:
415414
raise ValueError(
416-
f'Function {function_call.name} is not found in the tools_dict.'
415+
f"Function {function_call.name} is not found in the tools_dict."
417416
)
418417

419418
tool_context = ToolContext(
@@ -458,15 +457,15 @@ def __build_response_event(
458457
) -> Event:
459458
# Specs requires the result to be a dict.
460459
if not isinstance(function_result, dict):
461-
function_result = {'result': function_result}
460+
function_result = {"result": function_result}
462461

463462
part_function_response = types.Part.from_function_response(
464463
name=tool.name, response=function_result
465464
)
466465
part_function_response.function_response.id = tool_context.function_call_id
467466

468467
content = types.Content(
469-
role='user',
468+
role="user",
470469
parts=[part_function_response],
471470
)
472471

@@ -482,10 +481,10 @@ def __build_response_event(
482481

483482

484483
def merge_parallel_function_response_events(
485-
function_response_events: list['Event'],
486-
) -> 'Event':
484+
function_response_events: list["Event"],
485+
) -> "Event":
487486
if not function_response_events:
488-
raise ValueError('No function response events provided.')
487+
raise ValueError("No function response events provided.")
489488

490489
if len(function_response_events) == 1:
491490
return function_response_events[0]
@@ -513,7 +512,7 @@ def merge_parallel_function_response_events(
513512
invocation_id=Event.new_id(),
514513
author=base_event.author,
515514
branch=base_event.branch,
516-
content=types.Content(role='user', parts=merged_parts),
515+
content=types.Content(role="user", parts=merged_parts),
517516
actions=merged_actions, # Optionally merge actions if required
518517
)
519518

0 commit comments

Comments
 (0)