Skip to content

Commit 445dc18

Browse files
sasha-gitgcopybara-github
authored andcommitted
fix: remove duplicate session GET when using API server, unbreak auto_session_create when using API server
Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 874188082
1 parent 2dbd1f2 commit 445dc18

File tree

5 files changed

+203
-76
lines changed

5 files changed

+203
-76
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from ..errors.already_exists_error import AlreadyExistsError
6969
from ..errors.input_validation_error import InputValidationError
7070
from ..errors.not_found_error import NotFoundError
71+
from ..errors.session_not_found_error import SessionNotFoundError
7172
from ..evaluation.base_eval_service import InferenceConfig
7273
from ..evaluation.base_eval_service import InferenceRequest
7374
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
@@ -1558,53 +1559,68 @@ async def patch_memory(
15581559

15591560
@app.post("/run", response_model_exclude_none=True)
15601561
async def run_agent(req: RunAgentRequest) -> list[Event]:
1561-
session = await self.session_service.get_session(
1562-
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
1563-
)
1564-
if not session:
1565-
raise HTTPException(status_code=404, detail="Session not found")
15661562
runner = await self.get_runner_async(req.app_name)
1567-
async with Aclosing(
1568-
runner.run_async(
1569-
user_id=req.user_id,
1570-
session_id=req.session_id,
1571-
new_message=req.new_message,
1572-
state_delta=req.state_delta,
1573-
invocation_id=req.invocation_id,
1574-
)
1575-
) as agen:
1576-
events = [event async for event in agen]
1563+
try:
1564+
async with Aclosing(
1565+
runner.run_async(
1566+
user_id=req.user_id,
1567+
session_id=req.session_id,
1568+
new_message=req.new_message,
1569+
state_delta=req.state_delta,
1570+
invocation_id=req.invocation_id,
1571+
)
1572+
) as agen:
1573+
events = [event async for event in agen]
1574+
except SessionNotFoundError as e:
1575+
raise HTTPException(status_code=404, detail=str(e)) from e
15771576
logger.info("Generated %s events in agent run", len(events))
15781577
logger.debug("Events generated: %s", events)
15791578
return events
15801579

15811580
@app.post("/run_sse")
15821581
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
1583-
# SSE endpoint
1584-
session = await self.session_service.get_session(
1585-
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
1582+
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
1583+
runner = await self.get_runner_async(req.app_name)
1584+
agen = runner.run_async(
1585+
user_id=req.user_id,
1586+
session_id=req.session_id,
1587+
new_message=req.new_message,
1588+
state_delta=req.state_delta,
1589+
run_config=RunConfig(streaming_mode=stream_mode),
1590+
invocation_id=req.invocation_id,
15861591
)
1587-
if not session:
1588-
raise HTTPException(status_code=404, detail="Session not found")
1592+
1593+
# Eagerly advance the generator to trigger session validation
1594+
# before the streaming response is created. This lets us return
1595+
# a proper HTTP 404 for missing sessions without a redundant
1596+
# get_session call — the Runner's single _get_or_create_session
1597+
# call is the only one that runs.
1598+
first_event = None
1599+
first_error = None
1600+
try:
1601+
first_event = await anext(agen)
1602+
except SessionNotFoundError as e:
1603+
await agen.aclose()
1604+
raise HTTPException(status_code=404, detail=str(e)) from e
1605+
except StopAsyncIteration:
1606+
await agen.aclose()
1607+
except Exception as e:
1608+
first_error = e
15891609

15901610
# Convert the events to properly formatted SSE
15911611
async def event_generator():
1592-
try:
1593-
stream_mode = (
1594-
StreamingMode.SSE if req.streaming else StreamingMode.NONE
1595-
)
1596-
runner = await self.get_runner_async(req.app_name)
1597-
async with Aclosing(
1598-
runner.run_async(
1599-
user_id=req.user_id,
1600-
session_id=req.session_id,
1601-
new_message=req.new_message,
1602-
state_delta=req.state_delta,
1603-
run_config=RunConfig(streaming_mode=stream_mode),
1604-
invocation_id=req.invocation_id,
1605-
)
1606-
) as agen:
1607-
async for event in agen:
1612+
async with Aclosing(agen):
1613+
try:
1614+
if first_error:
1615+
raise first_error
1616+
1617+
async def all_events():
1618+
if first_event is not None:
1619+
yield first_event
1620+
async for event in agen:
1621+
yield event
1622+
1623+
async for event in all_events():
16081624
# ADK Web renders artifacts from `actions.artifactDelta`
16091625
# during part processing *and* during action processing
16101626
# 1) the original event with `artifactDelta` cleared (content)
@@ -1630,9 +1646,9 @@ async def event_generator():
16301646
"Generated event in agent run streaming: %s", sse_event
16311647
)
16321648
yield f"data: {sse_event}\n\n"
1633-
except Exception as e:
1634-
logger.exception("Error in event_generator: %s", e)
1635-
yield f"data: {json.dumps({'error': str(e)})}\n\n"
1649+
except Exception as e:
1650+
logger.exception("Error in event_generator: %s", e)
1651+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
16361652

16371653
# Returns a streaming response with the proper media type for SSE
16381654
return StreamingResponse(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from .not_found_error import NotFoundError
18+
19+
20+
class SessionNotFoundError(ValueError, NotFoundError):
21+
"""Raised when a session cannot be found.
22+
23+
Inherits from both ValueError (for backward compatibility) and NotFoundError
24+
(for semantic consistency with the project's error hierarchy).
25+
"""
26+
27+
def __init__(self, message="Session not found."):
28+
super().__init__(message)

src/google/adk/runners.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
4646
from .auth.credential_service.base_credential_service import BaseCredentialService
4747
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
48+
from .errors.session_not_found_error import SessionNotFoundError
4849
from .events.event import Event
4950
from .events.event import EventActions
5051
from .flows.llm_flows import contents
@@ -358,7 +359,7 @@ async def _get_or_create_session(
358359
359360
This helper first attempts to retrieve the session. If not found and
360361
auto_create_session is True, it creates a new session with the provided
361-
identifiers. Otherwise, it raises a ValueError with a helpful message.
362+
identifiers. Otherwise, it raises a SessionNotFoundError.
362363
363364
Args:
364365
user_id: The user ID of the session.
@@ -368,7 +369,8 @@ async def _get_or_create_session(
368369
The existing or newly created `Session`.
369370
370371
Raises:
371-
ValueError: If the session is not found and auto_create_session is False.
372+
SessionNotFoundError: If the session is not found and
373+
auto_create_session is False.
372374
"""
373375
session = await self.session_service.get_session(
374376
app_name=self.app_name, user_id=user_id, session_id=session_id
@@ -380,7 +382,7 @@ async def _get_or_create_session(
380382
)
381383
else:
382384
message = self._format_session_not_found_message(session_id)
383-
raise ValueError(message)
385+
raise SessionNotFoundError(message)
384386
return session
385387

386388
def run(

tests/unittests/cli/test_fast_api.py

Lines changed: 113 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.adk.cli import fast_api as fast_api_module
3333
from google.adk.cli.fast_api import get_fast_api_app
3434
from google.adk.errors.input_validation_error import InputValidationError
35+
from google.adk.errors.session_not_found_error import SessionNotFoundError
3536
from google.adk.evaluation.eval_case import EvalCase
3637
from google.adk.evaluation.eval_case import Invocation
3738
from google.adk.evaluation.eval_result import EvalSetResult
@@ -451,18 +452,28 @@ def list_eval_set_results(self, app_name):
451452
return MockEvalSetResultsManager()
452453

453454

454-
@pytest.fixture
455-
def test_app(
455+
def _create_test_client(
456456
mock_session_service,
457457
mock_artifact_service,
458458
mock_memory_service,
459459
mock_agent_loader,
460460
mock_eval_sets_manager,
461461
mock_eval_set_results_manager,
462+
**app_kwargs,
462463
):
463-
"""Create a TestClient for the FastAPI app without starting a server."""
464-
465-
# Patch multiple services and signal handlers
464+
"""Helper to create a TestClient with the given get_fast_api_app overrides."""
465+
defaults = dict(
466+
agents_dir=".",
467+
web=True,
468+
session_service_uri="",
469+
artifact_service_uri="",
470+
memory_service_uri="",
471+
allow_origins=["*"],
472+
a2a=False,
473+
host="127.0.0.1",
474+
port=8000,
475+
)
476+
defaults.update(app_kwargs)
466477
with (
467478
patch.object(signal, "signal", autospec=True, return_value=None),
468479
patch.object(
@@ -502,23 +513,28 @@ def test_app(
502513
return_value=mock_eval_set_results_manager,
503514
),
504515
):
505-
# Get the FastAPI app, but don't actually run it
506-
app = get_fast_api_app(
507-
agents_dir=".",
508-
web=True,
509-
session_service_uri="",
510-
artifact_service_uri="",
511-
memory_service_uri="",
512-
allow_origins=["*"],
513-
a2a=False, # Disable A2A for most tests
514-
host="127.0.0.1",
515-
port=8000,
516-
)
516+
app = get_fast_api_app(**defaults)
517+
return TestClient(app)
517518

518-
# Create a TestClient that doesn't start a real server
519-
client = TestClient(app)
520519

521-
return client
520+
@pytest.fixture
521+
def test_app(
522+
mock_session_service,
523+
mock_artifact_service,
524+
mock_memory_service,
525+
mock_agent_loader,
526+
mock_eval_sets_manager,
527+
mock_eval_set_results_manager,
528+
):
529+
"""Create a TestClient for the FastAPI app without starting a server."""
530+
return _create_test_client(
531+
mock_session_service,
532+
mock_artifact_service,
533+
mock_memory_service,
534+
mock_agent_loader,
535+
mock_eval_sets_manager,
536+
mock_eval_set_results_manager,
537+
)
522538

523539

524540
@pytest.fixture
@@ -1106,20 +1122,9 @@ def test_agent_run_sse_yields_error_object_on_exception(
11061122
"""Test /run_sse streams an error object if streaming raises."""
11071123
info = create_test_session
11081124

1109-
async def run_async_raises(
1110-
self,
1111-
*,
1112-
user_id: str,
1113-
session_id: str,
1114-
invocation_id: Optional[str] = None,
1115-
new_message: Optional[types.Content] = None,
1116-
state_delta: Optional[dict[str, Any]] = None,
1117-
run_config: Optional[RunConfig] = None,
1118-
):
1119-
del user_id, session_id, invocation_id, new_message, state_delta, run_config
1125+
async def run_async_raises(self, **kwargs):
11201126
raise ValueError("boom")
1121-
if False: # pylint: disable=using-constant-test
1122-
yield _event_1()
1127+
yield # make it an async generator # pylint: disable=unreachable
11231128

11241129
monkeypatch.setattr(Runner, "run_async", run_async_raises)
11251130

@@ -1637,5 +1642,80 @@ def test_version_endpoint(test_app):
16371642
assert "language_version" in data
16381643

16391644

1645+
@pytest.fixture
1646+
def test_app_auto_session(
1647+
mock_session_service,
1648+
mock_artifact_service,
1649+
mock_memory_service,
1650+
mock_agent_loader,
1651+
mock_eval_sets_manager,
1652+
mock_eval_set_results_manager,
1653+
):
1654+
"""Create a TestClient with auto_create_session=True."""
1655+
return _create_test_client(
1656+
mock_session_service,
1657+
mock_artifact_service,
1658+
mock_memory_service,
1659+
mock_agent_loader,
1660+
mock_eval_sets_manager,
1661+
mock_eval_set_results_manager,
1662+
web=False,
1663+
auto_create_session=True,
1664+
)
1665+
1666+
1667+
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
1668+
def test_auto_creates_session(
1669+
test_app_auto_session, test_session_info, endpoint
1670+
):
1671+
"""Test /run and /run_sse auto-create sessions when auto_create_session=True."""
1672+
payload = {
1673+
"app_name": test_session_info["app_name"],
1674+
"user_id": test_session_info["user_id"],
1675+
"session_id": "nonexistent_session",
1676+
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
1677+
}
1678+
1679+
response = test_app_auto_session.post(endpoint, json=payload)
1680+
assert response.status_code == 200
1681+
1682+
if endpoint == "/run":
1683+
data = response.json()
1684+
assert isinstance(data, list)
1685+
assert len(data) > 0
1686+
else:
1687+
sse_events = [
1688+
json.loads(line.removeprefix("data: "))
1689+
for line in response.text.splitlines()
1690+
if line.startswith("data: ")
1691+
]
1692+
assert len(sse_events) > 0
1693+
assert not any("error" in e for e in sse_events)
1694+
1695+
1696+
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
1697+
def test_returns_404_without_auto_create(
1698+
test_app, test_session_info, monkeypatch, endpoint
1699+
):
1700+
"""Test /run and /run_sse return 404 for missing sessions without auto_create."""
1701+
1702+
async def run_async_session_not_found(self, **kwargs):
1703+
raise SessionNotFoundError(f"Session not found: {kwargs['session_id']}")
1704+
yield # make it an async generator # pylint: disable=unreachable
1705+
1706+
monkeypatch.setattr(Runner, "run_async", run_async_session_not_found)
1707+
1708+
payload = {
1709+
"app_name": test_session_info["app_name"],
1710+
"user_id": test_session_info["user_id"],
1711+
"session_id": "nonexistent_session",
1712+
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
1713+
}
1714+
1715+
response = test_app.post(endpoint, json=payload)
1716+
assert response.status_code == 404
1717+
assert "Session not found" in response.json()["detail"]
1718+
1719+
16401720
if __name__ == "__main__":
16411721
pytest.main(["-xvs", __file__])

tests/unittests/test_runners.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from google.adk.apps.app import ResumabilityConfig
3030
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
3131
from google.adk.cli.utils.agent_loader import AgentLoader
32+
from google.adk.errors.session_not_found_error import SessionNotFoundError
3233
from google.adk.events.event import Event
3334
from google.adk.plugins.base_plugin import BasePlugin
3435
from google.adk.runners import Runner
@@ -243,7 +244,7 @@ def _infer_agent_origin(
243244
new_message=types.Content(role="user", parts=[]),
244245
)
245246

246-
with pytest.raises(ValueError) as excinfo:
247+
with pytest.raises(SessionNotFoundError) as excinfo:
247248
await agen.__anext__()
248249

249250
await agen.aclose()

0 commit comments

Comments
 (0)