Skip to content

Commit 0f38f17

Browse files
committed
refactor: use CallbackPipeline consistently in all callback execution sites
Address bot feedback (round 4) by replacing all manual callback iterations with CallbackPipeline.execute() for consistency and maintainability. Changes (9 locations): 1. base_agent.py: Use CallbackPipeline for before/after agent callbacks 2. callback_pipeline.py: Optimize single plugin callback execution 3. base_llm_flow.py: Use CallbackPipeline for before/after model callbacks 4. functions.py: Use CallbackPipeline for all tool callbacks (async + live) Impact: - Eliminates remaining manual callback iteration logic (~40 lines) - Achieves 100% consistency in callback execution - All sync/async handling and early exit logic centralized - Tests: 24/24 passing - Lint: 9.57/10 (improved from 9.49/10) #non-breaking
1 parent cd3416e commit 0f38f17

File tree

4 files changed

+42
-57
lines changed

4 files changed

+42
-57
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..utils.feature_decorator import experimental
4646
from .base_agent_config import BaseAgentConfig
4747
from .callback_context import CallbackContext
48+
from .callback_pipeline import CallbackPipeline
4849
from .callback_pipeline import normalize_callbacks
4950

5051
if TYPE_CHECKING:
@@ -429,14 +430,10 @@ async def _handle_before_agent_callback(
429430
# callbacks.
430431
callbacks = normalize_callbacks(self.before_agent_callback)
431432
if not before_agent_callback_content and callbacks:
432-
for callback in callbacks:
433-
before_agent_callback_content = callback(
434-
callback_context=callback_context
435-
)
436-
if inspect.isawaitable(before_agent_callback_content):
437-
before_agent_callback_content = await before_agent_callback_content
438-
if before_agent_callback_content:
439-
break
433+
pipeline = CallbackPipeline(callbacks)
434+
before_agent_callback_content = await pipeline.execute(
435+
callback_context=callback_context
436+
)
440437

441438
# Process the override content if exists, and further process the state
442439
# change if exists.
@@ -487,14 +484,10 @@ async def _handle_after_agent_callback(
487484
# callbacks.
488485
callbacks = normalize_callbacks(self.after_agent_callback)
489486
if not after_agent_callback_content and callbacks:
490-
for callback in callbacks:
491-
after_agent_callback_content = callback(
492-
callback_context=callback_context
493-
)
494-
if inspect.isawaitable(after_agent_callback_content):
495-
after_agent_callback_content = await after_agent_callback_content
496-
if after_agent_callback_content:
497-
break
487+
pipeline = CallbackPipeline(callbacks)
488+
after_agent_callback_content = await pipeline.execute(
489+
callback_context=callback_context
490+
)
498491

499492
# Process the override content if exists, and further process the state
500493
# change if exists.

src/google/adk/agents/callback_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ async def execute_with_plugins(
238238
... )
239239
"""
240240
# Step 1: Execute plugin callback (priority)
241-
result = await CallbackPipeline([plugin_callback]).execute(*args, **kwargs)
241+
result = plugin_callback(*args, **kwargs)
242+
if inspect.isawaitable(result):
243+
result = await result
242244
if result is not None:
243245
return result
244246

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from . import functions
3333
from ...agents.base_agent import BaseAgent
3434
from ...agents.callback_context import CallbackContext
35+
from ...agents.callback_pipeline import CallbackPipeline
3536
from ...agents.callback_pipeline import normalize_callbacks
3637
from ...agents.invocation_context import InvocationContext
3738
from ...agents.live_request_queue import LiveRequestQueue
@@ -819,14 +820,12 @@ async def _handle_before_model_callback(
819820
callbacks = normalize_callbacks(agent.before_model_callback)
820821
if not callbacks:
821822
return
822-
for callback in callbacks:
823-
callback_response = callback(
824-
callback_context=callback_context, llm_request=llm_request
825-
)
826-
if inspect.isawaitable(callback_response):
827-
callback_response = await callback_response
828-
if callback_response:
829-
return callback_response
823+
pipeline = CallbackPipeline(callbacks)
824+
callback_response = await pipeline.execute(
825+
callback_context=callback_context, llm_request=llm_request
826+
)
827+
if callback_response:
828+
return callback_response
830829

831830
async def _handle_after_model_callback(
832831
self,
@@ -877,14 +876,12 @@ async def _maybe_add_grounding_metadata(
877876
callbacks = normalize_callbacks(agent.after_model_callback)
878877
if not callbacks:
879878
return await _maybe_add_grounding_metadata()
880-
for callback in callbacks:
881-
callback_response = callback(
882-
callback_context=callback_context, llm_response=llm_response
883-
)
884-
if inspect.isawaitable(callback_response):
885-
callback_response = await callback_response
886-
if callback_response:
887-
return await _maybe_add_grounding_metadata(callback_response)
879+
pipeline = CallbackPipeline(callbacks)
880+
callback_response = await pipeline.execute(
881+
callback_context=callback_context, llm_response=llm_response
882+
)
883+
if callback_response:
884+
return await _maybe_add_grounding_metadata(callback_response)
888885
return await _maybe_add_grounding_metadata()
889886

890887
def _finalize_model_response_event(

src/google/adk/flows/llm_flows/functions.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.genai import types
3232

3333
from ...agents.active_streaming_tool import ActiveStreamingTool
34+
from ...agents.callback_pipeline import CallbackPipeline
3435
from ...agents.callback_pipeline import normalize_callbacks
3536
from ...agents.invocation_context import InvocationContext
3637
from ...auth.auth_tool import AuthToolArguments
@@ -318,14 +319,12 @@ async def _execute_single_function_call_async(
318319
# Step 2: If no overrides are provided from the plugins, further run the
319320
# canonical callback.
320321
if function_response is None:
321-
for callback in normalize_callbacks(agent.before_tool_callback):
322-
function_response = callback(
322+
callbacks = normalize_callbacks(agent.before_tool_callback)
323+
if callbacks:
324+
pipeline = CallbackPipeline(callbacks)
325+
function_response = await pipeline.execute(
323326
tool=tool, args=function_args, tool_context=tool_context
324327
)
325-
if inspect.isawaitable(function_response):
326-
function_response = await function_response
327-
if function_response:
328-
break
329328

330329
# Step 3: Otherwise, proceed calling the tool normally.
331330
if function_response is None:
@@ -361,17 +360,15 @@ async def _execute_single_function_call_async(
361360
# Step 5: If no overrides are provided from the plugins, further run the
362361
# canonical after_tool_callbacks.
363362
if altered_function_response is None:
364-
for callback in normalize_callbacks(agent.after_tool_callback):
365-
altered_function_response = callback(
363+
callbacks = normalize_callbacks(agent.after_tool_callback)
364+
if callbacks:
365+
pipeline = CallbackPipeline(callbacks)
366+
altered_function_response = await pipeline.execute(
366367
tool=tool,
367368
args=function_args,
368369
tool_context=tool_context,
369370
tool_response=function_response,
370371
)
371-
if inspect.isawaitable(altered_function_response):
372-
altered_function_response = await altered_function_response
373-
if altered_function_response:
374-
break
375372

376373
# Step 6: If alternative response exists from after_tool_callback, use it
377374
# instead of the original function response.
@@ -479,14 +476,12 @@ async def _execute_single_function_call_live(
479476

480477
# Handle before_tool_callbacks - iterate through the canonical callback
481478
# list
482-
for callback in normalize_callbacks(agent.before_tool_callback):
483-
function_response = callback(
479+
callbacks = normalize_callbacks(agent.before_tool_callback)
480+
if callbacks:
481+
pipeline = CallbackPipeline(callbacks)
482+
function_response = await pipeline.execute(
484483
tool=tool, args=function_args, tool_context=tool_context
485484
)
486-
if inspect.isawaitable(function_response):
487-
function_response = await function_response
488-
if function_response:
489-
break
490485

491486
if function_response is None:
492487
function_response = await _process_function_live_helper(
@@ -500,17 +495,15 @@ async def _execute_single_function_call_live(
500495

501496
# Calls after_tool_callback if it exists.
502497
altered_function_response = None
503-
for callback in normalize_callbacks(agent.after_tool_callback):
504-
altered_function_response = callback(
498+
callbacks = normalize_callbacks(agent.after_tool_callback)
499+
if callbacks:
500+
pipeline = CallbackPipeline(callbacks)
501+
altered_function_response = await pipeline.execute(
505502
tool=tool,
506503
args=function_args,
507504
tool_context=tool_context,
508505
tool_response=function_response,
509506
)
510-
if inspect.isawaitable(altered_function_response):
511-
altered_function_response = await altered_function_response
512-
if altered_function_response:
513-
break
514507

515508
if altered_function_response is not None:
516509
function_response = altered_function_response

0 commit comments

Comments
 (0)