Skip to content

Commit 0a369ea

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add ability to use existing sessions for streaming_agent_run_with_events calls.
PiperOrigin-RevId: 816470728
1 parent 64b0665 commit 0a369ea

File tree

2 files changed

+83
-34
lines changed
  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

2 files changed

+83
-34
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ def __init__(self, **kwargs):
169169
self.user_id: Optional[str] = kwargs.get("user_id", _DEFAULT_USER_ID)
170170
# The user ID.
171171

172+
self.session_id: Optional[str] = kwargs.get("session_id")
173+
# The session ID.
174+
172175

173176
class _StreamingRunResponse:
174177
"""Response object for `streaming_agent_run_with_events` method.
@@ -181,6 +184,8 @@ def __init__(self, **kwargs):
181184
# List of generated events.
182185
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
183186
# List of artifacts belonging to the session.
187+
self.session_id: Optional[str] = kwargs.get("session_id")
188+
# The session ID.
184189

185190
def dump(self) -> Dict[str, Any]:
186191
from vertexai.agent_engines import _utils
@@ -194,6 +199,8 @@ def dump(self) -> Dict[str, Any]:
194199
result["events"].append(event_dict)
195200
if self.artifacts:
196201
result["artifacts"] = [artifact.dump() for artifact in self.artifacts]
202+
if self.session_id:
203+
result["session_id"] = self.session_id
197204
return result
198205

199206

@@ -402,7 +409,10 @@ async def _init_session(
402409
auth = _Authorization(**auth)
403410
session_state[f"temp:{auth_id}"] = auth.access_token
404411

405-
session_id = f"temp_session_{random.randbytes(8).hex()}"
412+
if request.session_id:
413+
session_id = request.session_id
414+
else:
415+
session_id = f"temp_session_{random.randbytes(8).hex()}"
406416
session = await session_service.create_session(
407417
app_name=self._tmpl_attrs.get("app_name"),
408418
user_id=request.user_id,
@@ -450,7 +460,9 @@ async def _convert_response_events(
450460
"""Converts the events to the streaming run response object."""
451461
import collections
452462

453-
result = _StreamingRunResponse(events=events, artifacts=[])
463+
result = _StreamingRunResponse(
464+
events=events, artifacts=[], session_id=session_id
465+
)
454466

455467
# Save the generated artifacts into the result object.
456468
artifact_versions = collections.defaultdict(list)
@@ -685,22 +697,35 @@ async def streaming_agent_run_with_events(self, request_json: str):
685697
request = _StreamRunRequest(**json.loads(request_json))
686698
if not self._tmpl_attrs.get("in_memory_runner"):
687699
self.set_up()
688-
if not self._tmpl_attrs.get("artifact_service"):
689-
self.set_up()
690700
# Prepare the in-memory session.
691701
if not self._tmpl_attrs.get("in_memory_artifact_service"):
692702
self.set_up()
693703
if not self._tmpl_attrs.get("in_memory_session_service"):
694704
self.set_up()
695-
session = await self._init_session(
696-
session_service=self._tmpl_attrs.get("in_memory_session_service"),
697-
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
698-
request=request,
699-
)
705+
session_service = self._tmpl_attrs.get("in_memory_session_service")
706+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
707+
# Try to get the session, if it doesn't exist, create a new one.
708+
session = None
709+
if request.session_id:
710+
try:
711+
session = await session_service.get_session(
712+
app_name=self._tmpl_attrs.get("app_name"),
713+
user_id=request.user_id,
714+
session_id=request.session_id,
715+
)
716+
except RuntimeError:
717+
pass
718+
if not session:
719+
# Fall back to create session if the session is not found.
720+
session = await self._init_session(
721+
session_service=session_service,
722+
artifact_service=artifact_service,
723+
request=request,
724+
)
700725
if not session:
701726
raise RuntimeError("Session initialization failed.")
702727

703-
# Run the agent.
728+
# Run the agent
704729
message_for_agent = types.Content(**request.message)
705730
try:
706731
async for event in self._tmpl_attrs.get("in_memory_runner").run_async(
@@ -712,15 +737,16 @@ async def streaming_agent_run_with_events(self, request_json: str):
712737
user_id=request.user_id,
713738
session_id=session.id,
714739
events=[event],
715-
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
740+
artifact_service=artifact_service,
716741
)
717742
yield converted_event
718743
finally:
719-
await self._tmpl_attrs.get("in_memory_session_service").delete_session(
720-
app_name=self._tmpl_attrs.get("app_name"),
721-
user_id=request.user_id,
722-
session_id=session.id,
723-
)
744+
if session and not request.session_id:
745+
await session_service.delete_session(
746+
app_name=self._tmpl_attrs.get("app_name"),
747+
user_id=request.user_id,
748+
session_id=session.id,
749+
)
724750

725751
async def async_get_session(
726752
self,

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ def __init__(self, **kwargs):
183183
self.user_id: Optional[str] = kwargs.get("user_id", _DEFAULT_USER_ID)
184184
# The user ID.
185185

186+
self.session_id: Optional[str] = kwargs.get("session_id")
187+
# The session ID.
188+
186189

187190
class _StreamingRunResponse:
188191
"""Response object for `streaming_agent_run_with_events` method.
@@ -195,6 +198,8 @@ def __init__(self, **kwargs):
195198
# List of generated events.
196199
self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts")
197200
# List of artifacts belonging to the session.
201+
self.session_id: Optional[str] = kwargs.get("session_id")
202+
# The session ID.
198203

199204
def dump(self) -> Dict[str, Any]:
200205
from vertexai.agent_engines import _utils
@@ -208,6 +213,8 @@ def dump(self) -> Dict[str, Any]:
208213
result["events"].append(event_dict)
209214
if self.artifacts:
210215
result["artifacts"] = [artifact.dump() for artifact in self.artifacts]
216+
if self.session_id:
217+
result["session_id"] = self.session_id
211218
return result
212219

213220

@@ -383,7 +390,10 @@ async def _init_session(
383390
auth = _Authorization(**auth)
384391
session_state[f"temp:{auth_id}"] = auth.access_token
385392

386-
session_id = f"temp_session_{random.randbytes(8).hex()}"
393+
if request.session_id:
394+
session_id = request.session_id
395+
else:
396+
session_id = f"temp_session_{random.randbytes(8).hex()}"
387397
session = await session_service.create_session(
388398
app_name=self._tmpl_attrs.get("app_name"),
389399
user_id=request.user_id,
@@ -431,7 +441,9 @@ async def _convert_response_events(
431441
"""Converts the events to the streaming run response object."""
432442
import collections
433443

434-
result = _StreamingRunResponse(events=events, artifacts=[])
444+
result = _StreamingRunResponse(
445+
events=events, artifacts=[], session_id=session_id
446+
)
435447

436448
# Save the generated artifacts into the result object.
437449
artifact_versions = collections.defaultdict(list)
@@ -735,21 +747,33 @@ async def _invoke_agent_async():
735747
request = _StreamRunRequest(**json.loads(request_json))
736748
if not self._tmpl_attrs.get("in_memory_runner"):
737749
self.set_up()
738-
if not self._tmpl_attrs.get("artifact_service"):
739-
self.set_up()
740750
# Prepare the in-memory session.
741751
if not self._tmpl_attrs.get("in_memory_artifact_service"):
742752
self.set_up()
743753
if not self._tmpl_attrs.get("in_memory_session_service"):
744754
self.set_up()
745-
session = await self._init_session(
746-
session_service=self._tmpl_attrs.get("in_memory_session_service"),
747-
artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"),
748-
request=request,
749-
)
755+
session_service = self._tmpl_attrs.get("in_memory_session_service")
756+
artifact_service = self._tmpl_attrs.get("in_memory_artifact_service")
757+
# Try to get the session, if it doesn't exist, create a new one.
758+
session = None
759+
if request.session_id:
760+
try:
761+
session = await session_service.get_session(
762+
app_name=self._tmpl_attrs.get("app_name"),
763+
user_id=request.user_id,
764+
session_id=request.session_id,
765+
)
766+
except RuntimeError:
767+
pass
768+
if not session:
769+
# Fall back to create session if the session is not found.
770+
session = await self._init_session(
771+
session_service=session_service,
772+
artifact_service=artifact_service,
773+
request=request,
774+
)
750775
if not session:
751776
raise RuntimeError("Session initialization failed.")
752-
753777
# Run the agent.
754778
message_for_agent = types.Content(**request.message)
755779
try:
@@ -762,17 +786,16 @@ async def _invoke_agent_async():
762786
user_id=request.user_id,
763787
session_id=session.id,
764788
events=[event],
765-
artifact_service=self._tmpl_attrs.get(
766-
"in_memory_artifact_service"
767-
),
789+
artifact_service=artifact_service,
768790
)
769791
event_queue.put(converted_event)
770792
finally:
771-
await self._tmpl_attrs.get("in_memory_session_service").delete_session(
772-
app_name=self._tmpl_attrs.get("app_name"),
773-
user_id=request.user_id,
774-
session_id=session.id,
775-
)
793+
if session and not request.session_id:
794+
await session_service.delete_session(
795+
app_name=self._tmpl_attrs.get("app_name"),
796+
user_id=request.user_id,
797+
session_id=session.id,
798+
)
776799

777800
def _asyncio_thread_main():
778801
try:

0 commit comments

Comments
 (0)