Skip to content

Commit 091b046

Browse files
committed
refactor: use CallbackPipeline consistently in all callback execution sites
Address bot feedback (round 4) by replacing all manual callback iterations with CallbackPipeline.execute() for consistency and maintainability. Changes (9 locations): 1. base_agent.py: Use CallbackPipeline for before/after agent callbacks 2. callback_pipeline.py: Optimize single plugin callback execution 3. base_llm_flow.py: Use CallbackPipeline for before/after model callbacks 4. functions.py: Use CallbackPipeline for all tool callbacks (async + live) Impact: - Eliminates remaining manual callback iteration logic (~40 lines) - Achieves 100% consistency in callback execution - All sync/async handling and early exit logic centralized - Tests: 24/24 passing - Lint: 9.57/10 (improved from 9.49/10) #non-breaking
1 parent 86da549 commit 091b046

File tree

4 files changed

+42
-57
lines changed

4 files changed

+42
-57
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..utils.feature_decorator import experimental
4646
from .base_agent_config import BaseAgentConfig
4747
from .callback_context import CallbackContext
48+
from .callback_pipeline import CallbackPipeline
4849
from .callback_pipeline import normalize_callbacks
4950

5051
if TYPE_CHECKING:
@@ -441,14 +442,10 @@ async def __handle_before_agent_callback(
441442
# callbacks.
442443
callbacks = normalize_callbacks(self.before_agent_callback)
443444
if not before_agent_callback_content and callbacks:
444-
for callback in callbacks:
445-
before_agent_callback_content = callback(
446-
callback_context=callback_context
447-
)
448-
if inspect.isawaitable(before_agent_callback_content):
449-
before_agent_callback_content = await before_agent_callback_content
450-
if before_agent_callback_content:
451-
break
445+
pipeline = CallbackPipeline(callbacks)
446+
before_agent_callback_content = await pipeline.execute(
447+
callback_context=callback_context
448+
)
452449

453450
# Process the override content if exists, and further process the state
454451
# change if exists.
@@ -499,14 +496,10 @@ async def __handle_after_agent_callback(
499496
# callbacks.
500497
callbacks = normalize_callbacks(self.after_agent_callback)
501498
if not after_agent_callback_content and callbacks:
502-
for callback in callbacks:
503-
after_agent_callback_content = callback(
504-
callback_context=callback_context
505-
)
506-
if inspect.isawaitable(after_agent_callback_content):
507-
after_agent_callback_content = await after_agent_callback_content
508-
if after_agent_callback_content:
509-
break
499+
pipeline = CallbackPipeline(callbacks)
500+
after_agent_callback_content = await pipeline.execute(
501+
callback_context=callback_context
502+
)
510503

511504
# Process the override content if exists, and further process the state
512505
# change if exists.

src/google/adk/agents/callback_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ async def execute_with_plugins(
238238
... )
239239
"""
240240
# Step 1: Execute plugin callback (priority)
241-
result = await CallbackPipeline([plugin_callback]).execute(*args, **kwargs)
241+
result = plugin_callback(*args, **kwargs)
242+
if inspect.isawaitable(result):
243+
result = await result
242244
if result is not None:
243245
return result
244246

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from . import functions
3333
from ...agents.base_agent import BaseAgent
3434
from ...agents.callback_context import CallbackContext
35+
from ...agents.callback_pipeline import CallbackPipeline
3536
from ...agents.callback_pipeline import normalize_callbacks
3637
from ...agents.invocation_context import InvocationContext
3738
from ...agents.live_request_queue import LiveRequestQueue
@@ -810,14 +811,12 @@ async def _handle_before_model_callback(
810811
callbacks = normalize_callbacks(agent.before_model_callback)
811812
if not callbacks:
812813
return
813-
for callback in callbacks:
814-
callback_response = callback(
815-
callback_context=callback_context, llm_request=llm_request
816-
)
817-
if inspect.isawaitable(callback_response):
818-
callback_response = await callback_response
819-
if callback_response:
820-
return callback_response
814+
pipeline = CallbackPipeline(callbacks)
815+
callback_response = await pipeline.execute(
816+
callback_context=callback_context, llm_request=llm_request
817+
)
818+
if callback_response:
819+
return callback_response
821820

822821
async def _handle_after_model_callback(
823822
self,
@@ -868,14 +867,12 @@ async def _maybe_add_grounding_metadata(
868867
callbacks = normalize_callbacks(agent.after_model_callback)
869868
if not callbacks:
870869
return await _maybe_add_grounding_metadata()
871-
for callback in callbacks:
872-
callback_response = callback(
873-
callback_context=callback_context, llm_response=llm_response
874-
)
875-
if inspect.isawaitable(callback_response):
876-
callback_response = await callback_response
877-
if callback_response:
878-
return await _maybe_add_grounding_metadata(callback_response)
870+
pipeline = CallbackPipeline(callbacks)
871+
callback_response = await pipeline.execute(
872+
callback_context=callback_context, llm_response=llm_response
873+
)
874+
if callback_response:
875+
return await _maybe_add_grounding_metadata(callback_response)
879876
return await _maybe_add_grounding_metadata()
880877

881878
def _finalize_model_response_event(

src/google/adk/flows/llm_flows/functions.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.genai import types
3232

3333
from ...agents.active_streaming_tool import ActiveStreamingTool
34+
from ...agents.callback_pipeline import CallbackPipeline
3435
from ...agents.callback_pipeline import normalize_callbacks
3536
from ...agents.invocation_context import InvocationContext
3637
from ...auth.auth_tool import AuthToolArguments
@@ -302,14 +303,12 @@ async def _execute_single_function_call_async(
302303
# Step 2: If no overrides are provided from the plugins, further run the
303304
# canonical callback.
304305
if function_response is None:
305-
for callback in normalize_callbacks(agent.before_tool_callback):
306-
function_response = callback(
306+
callbacks = normalize_callbacks(agent.before_tool_callback)
307+
if callbacks:
308+
pipeline = CallbackPipeline(callbacks)
309+
function_response = await pipeline.execute(
307310
tool=tool, args=function_args, tool_context=tool_context
308311
)
309-
if inspect.isawaitable(function_response):
310-
function_response = await function_response
311-
if function_response:
312-
break
313312

314313
# Step 3: Otherwise, proceed calling the tool normally.
315314
if function_response is None:
@@ -345,17 +344,15 @@ async def _execute_single_function_call_async(
345344
# Step 5: If no overrides are provided from the plugins, further run the
346345
# canonical after_tool_callbacks.
347346
if altered_function_response is None:
348-
for callback in normalize_callbacks(agent.after_tool_callback):
349-
altered_function_response = callback(
347+
callbacks = normalize_callbacks(agent.after_tool_callback)
348+
if callbacks:
349+
pipeline = CallbackPipeline(callbacks)
350+
altered_function_response = await pipeline.execute(
350351
tool=tool,
351352
args=function_args,
352353
tool_context=tool_context,
353354
tool_response=function_response,
354355
)
355-
if inspect.isawaitable(altered_function_response):
356-
altered_function_response = await altered_function_response
357-
if altered_function_response:
358-
break
359356

360357
# Step 6: If alternative response exists from after_tool_callback, use it
361358
# instead of the original function response.
@@ -463,14 +460,12 @@ async def _execute_single_function_call_live(
463460

464461
# Handle before_tool_callbacks - iterate through the canonical callback
465462
# list
466-
for callback in normalize_callbacks(agent.before_tool_callback):
467-
function_response = callback(
463+
callbacks = normalize_callbacks(agent.before_tool_callback)
464+
if callbacks:
465+
pipeline = CallbackPipeline(callbacks)
466+
function_response = await pipeline.execute(
468467
tool=tool, args=function_args, tool_context=tool_context
469468
)
470-
if inspect.isawaitable(function_response):
471-
function_response = await function_response
472-
if function_response:
473-
break
474469

475470
if function_response is None:
476471
function_response = await _process_function_live_helper(
@@ -484,17 +479,15 @@ async def _execute_single_function_call_live(
484479

485480
# Calls after_tool_callback if it exists.
486481
altered_function_response = None
487-
for callback in normalize_callbacks(agent.after_tool_callback):
488-
altered_function_response = callback(
482+
callbacks = normalize_callbacks(agent.after_tool_callback)
483+
if callbacks:
484+
pipeline = CallbackPipeline(callbacks)
485+
altered_function_response = await pipeline.execute(
489486
tool=tool,
490487
args=function_args,
491488
tool_context=tool_context,
492489
tool_response=function_response,
493490
)
494-
if inspect.isawaitable(altered_function_response):
495-
altered_function_response = await altered_function_response
496-
if altered_function_response:
497-
break
498491

499492
if altered_function_response is not None:
500493
function_response = altered_function_response

0 commit comments

Comments
 (0)