Skip to content

Commit f6a2347

Browse files
committed
Address review comments
1 parent a0f8f38 commit f6a2347

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

src/google/adk/sessions/base_session_service.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,11 @@ async def append_event(self, session: Session, event: Event) -> Event:
106106
"""Appends an event to a session object."""
107107
if event.partial:
108108
return event
109-
event = self._trim_temp_delta_state(event)
109+
# Update session state with ALL keys (including temp:) so they're accessible
110+
# during callbacks within the same invocation
110111
self._update_session_state(session, event)
112+
# Trim temp: keys from the event before persisting to avoid storing them
113+
event = self._trim_temp_delta_state(event)
111114
session.events.append(event)
112115
return event
113116

@@ -127,5 +130,4 @@ def _update_session_state(self, session: Session, event: Event) -> None:
127130
"""Updates the session state based on the event."""
128131
if not event.actions or not event.actions.state_delta:
129132
return
130-
for key, value in event.actions.state_delta.items():
131-
session.state.update({key: value})
133+
session.state.update(event.actions.state_delta)

tests/unittests/test_runners.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,15 +1287,14 @@ async def after_agent_callback(self, *, agent, callback_context):
12871287
)
12881288

12891289
# Run the agent
1290-
events = []
1291-
async for event in runner.run_async(
1290+
async for _ in runner.run_async(
12921291
user_id=TEST_USER_ID,
12931292
session_id=TEST_SESSION_ID,
12941293
new_message=types.Content(
12951294
role="user", parts=[types.Part(text="test message")]
12961295
),
12971296
):
1298-
events.append(event)
1297+
pass
12991298

13001299
# Verify temp state was accessible during callbacks
13011300
assert state_seen_in_before_agent["temp:test_key"] == "test_value"
@@ -1322,5 +1321,78 @@ async def after_agent_callback(self, *, agent, callback_context):
13221321
assert "temp:test_key" not in event.actions.state_delta
13231322

13241323

1324+
@pytest.mark.asyncio
1325+
async def test_temp_state_from_state_delta_accessible_in_callbacks():
1326+
"""Tests that temp: state set via run_async state_delta parameter is
1327+
accessible during lifecycle callbacks but not persisted."""
1328+
1329+
# Track what state was seen during callbacks
1330+
state_seen_in_before_agent = {}
1331+
1332+
class StateAccessPlugin(BasePlugin):
1333+
"""Plugin that accesses state during callbacks."""
1334+
1335+
async def before_agent_callback(self, *, agent, callback_context):
1336+
# Check if temp state from state_delta is accessible
1337+
state_seen_in_before_agent["temp:from_run_async"] = (
1338+
callback_context.state.get("temp:from_run_async")
1339+
)
1340+
state_seen_in_before_agent["normal:from_run_async"] = (
1341+
callback_context.state.get("normal:from_run_async")
1342+
)
1343+
return None
1344+
1345+
# Setup
1346+
session_service = InMemorySessionService()
1347+
plugin = StateAccessPlugin(name="state_access")
1348+
1349+
agent = MockAgent(name="test_agent")
1350+
runner = Runner(
1351+
app_name=TEST_APP_ID,
1352+
agent=agent,
1353+
session_service=session_service,
1354+
plugins=[plugin],
1355+
auto_create_session=True,
1356+
)
1357+
1358+
# Run the agent with state_delta containing both temp and normal keys
1359+
async for _ in runner.run_async(
1360+
user_id=TEST_USER_ID,
1361+
session_id=TEST_SESSION_ID,
1362+
new_message=types.Content(
1363+
role="user", parts=[types.Part(text="test message")]
1364+
),
1365+
state_delta={
1366+
"temp:from_run_async": "temp_value",
1367+
"normal:from_run_async": "normal_value",
1368+
},
1369+
):
1370+
pass
1371+
1372+
# Verify temp state from state_delta WAS accessible during callbacks
1373+
assert (
1374+
state_seen_in_before_agent["temp:from_run_async"] == "temp_value"
1375+
), "temp: state from state_delta should be accessible in callbacks"
1376+
assert state_seen_in_before_agent["normal:from_run_async"] == "normal_value"
1377+
1378+
# Verify temp state is NOT persisted in the session
1379+
session = await session_service.get_session(
1380+
app_name=TEST_APP_ID,
1381+
user_id=TEST_USER_ID,
1382+
session_id=TEST_SESSION_ID,
1383+
)
1384+
1385+
# Normal state should be persisted
1386+
assert session.state.get("normal:from_run_async") == "normal_value"
1387+
1388+
# Temp state should NOT be persisted
1389+
assert "temp:from_run_async" not in session.state
1390+
1391+
# Verify temp state is also not in any event's state_delta
1392+
for event in session.events:
1393+
if event.actions and event.actions.state_delta:
1394+
assert "temp:from_run_async" not in event.actions.state_delta
1395+
1396+
13251397
if __name__ == "__main__":
13261398
pytest.main([__file__])

0 commit comments

Comments
 (0)