Skip to content

Commit e2072af

Browse files
GWealecopybara-github
authored andcommitted
feat: migrate invocation_context to callback_context
Update plugin manager and built-in plugins to prioritize CallbackContext. Keep InvocationContext access for legacy plugins with adapter. Change callback docs/tests to cover the new context. PiperOrigin-RevId: 818798087
1 parent fa84bcb commit e2072af

File tree

10 files changed

+240
-110
lines changed

10 files changed

+240
-110
lines changed

src/google/adk/agents/readonly_context.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ def state(self) -> MappingProxyType[str, Any]:
5555
return MappingProxyType(self._invocation_context.session.state)
5656

5757
@property
58-
def session(self) -> Session:
59-
"""The current session for this invocation."""
60-
return self._invocation_context.session
58+
def user_id(self) -> str:
59+
"""The user ID for the current invocation."""
60+
return self._invocation_context.user_id
61+
62+
@property
63+
def app_name(self) -> str:
64+
"""The application name for the current invocation."""
65+
return self._invocation_context.app_name
66+
67+
@property
68+
def session_id(self) -> str:
69+
"""The session ID for the current invocation."""
70+
return self._invocation_context.session.id
71+
72+
@property
73+
def branch(self) -> Optional[str]:
74+
"""The branch name for the current invocation, if any."""
75+
return self._invocation_context.branch

src/google/adk/cli/plugins/recordings_plugin.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from .recordings_schema import ToolRecording
4040

4141
if TYPE_CHECKING:
42-
from ...agents.invocation_context import InvocationContext
4342
from ...tools.base_tool import BaseTool
4443
from ...tools.tool_context import ToolContext
4544

@@ -75,10 +74,10 @@ def __init__(self, *, name: str = "adk_recordings") -> None:
7574

7675
@override
7776
async def before_run_callback(
78-
self, *, invocation_context: InvocationContext
77+
self, *, callback_context: CallbackContext
7978
) -> Optional[types.Content]:
8079
"""Always create fresh per-invocation recording state when enabled."""
81-
ctx = CallbackContext(invocation_context)
80+
ctx = callback_context
8281
if self._is_record_mode_on(ctx):
8382
# Always create/overwrite the state for this invocation
8483
self._create_invocation_state(ctx)
@@ -280,10 +279,10 @@ async def on_tool_error_callback(
280279

281280
@override
282281
async def after_run_callback(
283-
self, *, invocation_context: InvocationContext
282+
self, *, callback_context: CallbackContext
284283
) -> None:
285284
"""Finalize and persist recordings, then clean per-invocation state."""
286-
ctx = CallbackContext(invocation_context)
285+
ctx = callback_context
287286
if not self._is_record_mode_on(ctx):
288287
return None
289288

src/google/adk/cli/plugins/replay_plugin.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from .recordings_schema import ToolRecording
3939

4040
if TYPE_CHECKING:
41-
from ...agents.invocation_context import InvocationContext
4241
from ...tools.base_tool import BaseTool
4342
from ...tools.tool_context import ToolContext
4443

@@ -81,10 +80,10 @@ def __init__(self, *, name: str = "adk_replay") -> None:
8180

8281
@override
8382
async def before_run_callback(
84-
self, *, invocation_context: InvocationContext
83+
self, *, callback_context: CallbackContext
8584
) -> Optional[types.Content]:
8685
"""Load replay recordings when enabled."""
87-
ctx = CallbackContext(invocation_context)
86+
ctx = callback_context
8887
if self._is_replay_mode_on(ctx):
8988
# Load the replay state for this invocation
9089
self._load_invocation_state(ctx)
@@ -156,10 +155,10 @@ async def before_tool_callback(
156155

157156
@override
158157
async def after_run_callback(
159-
self, *, invocation_context: InvocationContext
158+
self, *, callback_context: CallbackContext
160159
) -> None:
161160
"""Clean up replay state after invocation completes."""
162-
ctx = CallbackContext(invocation_context)
161+
ctx = callback_context
163162
if not self._is_replay_mode_on(ctx):
164163
return None
165164

src/google/adk/plugins/base_plugin.py

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -111,81 +111,125 @@ def __init__(self, name: str):
111111
super().__init__()
112112
self.name = name
113113

114+
if TYPE_CHECKING:
115+
116+
async def on_user_message_callback(
117+
self,
118+
*,
119+
callback_context: Optional[CallbackContext] = None,
120+
user_message: types.Content,
121+
invocation_context: Optional[InvocationContext] = None,
122+
) -> Optional[types.Content]:
123+
"""Callback executed when a user message is received before an invocation starts.
124+
125+
Plugins can implement this with either callback_context (new) or
126+
invocation_context (deprecated) or both.
127+
"""
128+
129+
async def before_run_callback(
130+
self,
131+
*,
132+
callback_context: Optional[CallbackContext] = None,
133+
invocation_context: Optional[InvocationContext] = None,
134+
) -> Optional[types.Content]:
135+
"""Callback executed before the ADK runner runs.
136+
137+
Plugins can implement this with either callback_context (new) or
138+
invocation_context (deprecated) or both.
139+
"""
140+
141+
async def on_event_callback(
142+
self,
143+
*,
144+
callback_context: Optional[CallbackContext] = None,
145+
event: Event,
146+
invocation_context: Optional[InvocationContext] = None,
147+
) -> Optional[Event]:
148+
"""Callback executed after an event is yielded from runner.
149+
150+
Plugins can implement this with either callback_context (new) or
151+
invocation_context (deprecated) or both.
152+
"""
153+
154+
async def after_run_callback(
155+
self,
156+
*,
157+
callback_context: Optional[CallbackContext] = None,
158+
invocation_context: Optional[InvocationContext] = None,
159+
) -> None:
160+
"""Callback executed after an ADK runner run has completed.
161+
162+
Plugins can implement this with either callback_context (new) or
163+
invocation_context (deprecated) or both.
164+
"""
165+
166+
# Runtime implementation accepts both via **kwargs
114167
async def on_user_message_callback(
115-
self,
116-
*,
117-
invocation_context: InvocationContext,
118-
user_message: types.Content,
168+
self, **kwargs: Any
119169
) -> Optional[types.Content]:
120170
"""Callback executed when a user message is received before an invocation starts.
121171
122172
This callback helps logging and modifying the user message before the
123173
runner starts the invocation.
124174
125175
Args:
126-
invocation_context: The context for the entire invocation.
176+
callback_context: The context for the callback execution.
127177
user_message: The message content input by user.
178+
invocation_context: DEPRECATED. Use callback_context instead. The context
179+
for the entire invocation. This parameter is maintained for backward
180+
compatibility and will be removed in a future version.
128181
129182
Returns:
130-
An optional `types.Content` to be returned to the ADK. Returning a
131-
value to replace the user message. Returning `None` to proceed
132-
normally.
183+
The modified user message or None if no modification is needed.
133184
"""
134-
pass
185+
return None
135186

136-
async def before_run_callback(
137-
self, *, invocation_context: InvocationContext
138-
) -> Optional[types.Content]:
187+
async def before_run_callback(self, **kwargs: Any) -> Optional[types.Content]:
139188
"""Callback executed before the ADK runner runs.
140189
141-
This is the first callback to be called in the lifecycle, ideal for global
142-
setup or initialization tasks.
190+
This is the first lifecycle hook and is ideal for global setup, logging,
191+
or checks that may stop the invocation from running.
143192
144193
Args:
145-
invocation_context: The context for the entire invocation, containing
146-
session information, the root agent, etc.
194+
callback_context: The context for the callback execution.
195+
invocation_context: DEPRECATED. Use callback_context instead. The context
196+
for the entire invocation. This parameter is maintained for backward
197+
compatibility and will be removed in a future version.
147198
148199
Returns:
149-
An optional `Event` to be returned to the ADK. Returning a value to
150-
halt execution of the runner and ends the runner with that event. Return
151-
`None` to proceed normally.
200+
Optional `types.Content` to halt execution and return the value to the
201+
caller. Return `None` to proceed normally.
152202
"""
153-
pass
203+
return None
154204

155-
async def on_event_callback(
156-
self, *, invocation_context: InvocationContext, event: Event
157-
) -> Optional[Event]:
205+
async def on_event_callback(self, **kwargs: Any) -> Optional[Event]:
158206
"""Callback executed after an event is yielded from runner.
159207
160-
This is the ideal place to make modification to the event before the event
161-
is handled by the underlying agent app.
162-
163208
Args:
164-
invocation_context: The context for the entire invocation.
209+
callback_context: The context for the callback execution.
165210
event: The event raised by the runner.
211+
invocation_context: DEPRECATED. Use callback_context instead. The context
212+
for the entire invocation. This parameter is maintained for backward
213+
compatibility and will be removed in a future version.
166214
167215
Returns:
168-
An optional value. A non-`None` return may be used by the framework to
169-
modify or replace the response. Returning `None` allows the original
170-
response to be used.
216+
The modified event or None if no modification is needed.
171217
"""
172-
pass
218+
return None
173219

174-
async def after_run_callback(
175-
self, *, invocation_context: InvocationContext
176-
) -> None:
220+
async def after_run_callback(self, **kwargs: Any) -> None:
177221
"""Callback executed after an ADK runner run has completed.
178222
179-
This is the final callback in the ADK lifecycle, suitable for cleanup, final
180-
logging, or reporting tasks.
181-
182223
Args:
183-
invocation_context: The context for the entire invocation.
224+
callback_context: The context for the callback execution.
225+
invocation_context: DEPRECATED. Use callback_context instead. The context
226+
for the entire invocation. This parameter is maintained for backward
227+
compatibility and will be removed in a future version.
184228
185229
Returns:
186230
None
187231
"""
188-
pass
232+
return None
189233

190234
async def before_agent_callback(
191235
self, *, agent: BaseAgent, callback_context: CallbackContext

src/google/adk/plugins/global_instruction_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def before_model_callback(
7979
return None
8080

8181
# Resolve the global instruction (handle both string and InstructionProvider)
82-
readonly_context = ReadonlyContext(callback_context.invocation_context)
82+
readonly_context = callback_context
8383
final_global_instruction = await self._resolve_global_instruction(
8484
readonly_context
8585
)

src/google/adk/plugins/logging_plugin.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -69,38 +69,32 @@ def __init__(self, name: str = "logging_plugin"):
6969
async def on_user_message_callback(
7070
self,
7171
*,
72-
invocation_context: InvocationContext,
72+
callback_context: CallbackContext,
7373
user_message: types.Content,
7474
) -> Optional[types.Content]:
7575
"""Log user message and invocation start."""
7676
self._log(f"🚀 USER MESSAGE RECEIVED")
77-
self._log(f" Invocation ID: {invocation_context.invocation_id}")
78-
self._log(f" Session ID: {invocation_context.session.id}")
79-
self._log(f" User ID: {invocation_context.user_id}")
80-
self._log(f" App Name: {invocation_context.app_name}")
81-
self._log(
82-
" Root Agent:"
83-
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
84-
)
77+
self._log(f" Invocation ID: {callback_context.invocation_id}")
78+
self._log(f" Session ID: {callback_context.session_id}")
79+
self._log(f" User ID: {callback_context.user_id}")
80+
self._log(f" App Name: {callback_context.app_name}")
81+
self._log(f" Root Agent: {callback_context.agent_name}")
8582
self._log(f" User Content: {self._format_content(user_message)}")
86-
if invocation_context.branch:
87-
self._log(f" Branch: {invocation_context.branch}")
83+
if callback_context.branch:
84+
self._log(f" Branch: {callback_context.branch}")
8885
return None
8986

9087
async def before_run_callback(
91-
self, *, invocation_context: InvocationContext
88+
self, *, callback_context: CallbackContext
9289
) -> Optional[types.Content]:
9390
"""Log invocation start."""
9491
self._log(f"🏃 INVOCATION STARTING")
95-
self._log(f" Invocation ID: {invocation_context.invocation_id}")
96-
self._log(
97-
" Starting Agent:"
98-
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
99-
)
92+
self._log(f" Invocation ID: {callback_context.invocation_id}")
93+
self._log(f" Starting Agent: {callback_context.agent_name}")
10094
return None
10195

10296
async def on_event_callback(
103-
self, *, invocation_context: InvocationContext, event: Event
97+
self, *, callback_context: CallbackContext, event: Event
10498
) -> Optional[Event]:
10599
"""Log events yielded from the runner."""
106100
self._log(f"📢 EVENT YIELDED")
@@ -123,15 +117,12 @@ async def on_event_callback(
123117
return None
124118

125119
async def after_run_callback(
126-
self, *, invocation_context: InvocationContext
120+
self, *, callback_context: CallbackContext
127121
) -> Optional[None]:
128122
"""Log invocation completion."""
129123
self._log(f"✅ INVOCATION COMPLETED")
130-
self._log(f" Invocation ID: {invocation_context.invocation_id}")
131-
self._log(
132-
" Final Agent:"
133-
f" {invocation_context.agent.name if hasattr(invocation_context.agent, 'name') else 'Unknown'}"
134-
)
124+
self._log(f" Invocation ID: {callback_context.invocation_id}")
125+
self._log(f" Final Agent: {callback_context.agent_name}")
135126
return None
136127

137128
async def before_agent_callback(
@@ -141,8 +132,8 @@ async def before_agent_callback(
141132
self._log(f"🤖 AGENT STARTING")
142133
self._log(f" Agent Name: {callback_context.agent_name}")
143134
self._log(f" Invocation ID: {callback_context.invocation_id}")
144-
if callback_context._invocation_context.branch:
145-
self._log(f" Branch: {callback_context._invocation_context.branch}")
135+
if callback_context.branch:
136+
self._log(f" Branch: {callback_context.branch}")
146137
return None
147138

148139
async def after_agent_callback(

0 commit comments

Comments
 (0)