@@ -169,6 +169,9 @@ def __init__(self, **kwargs):
169169 self .user_id : Optional [str ] = kwargs .get ("user_id" , _DEFAULT_USER_ID )
170170 # The user ID.
171171
172+ self .session_id : Optional [str ] = kwargs .get ("session_id" )
173+ # The session ID.
174+
172175
173176class _StreamingRunResponse :
174177 """Response object for `streaming_agent_run_with_events` method.
@@ -181,6 +184,8 @@ def __init__(self, **kwargs):
181184 # List of generated events.
182185 self .artifacts : Optional [List [_Artifact ]] = kwargs .get ("artifacts" )
183186 # List of artifacts belonging to the session.
187+ self .session_id : Optional [str ] = kwargs .get ("session_id" )
188+ # The session ID.
184189
185190 def dump (self ) -> Dict [str , Any ]:
186191 from vertexai .agent_engines import _utils
@@ -194,6 +199,8 @@ def dump(self) -> Dict[str, Any]:
194199 result ["events" ].append (event_dict )
195200 if self .artifacts :
196201 result ["artifacts" ] = [artifact .dump () for artifact in self .artifacts ]
202+ if self .session_id :
203+ result ["session_id" ] = self .session_id
197204 return result
198205
199206
@@ -402,7 +409,10 @@ async def _init_session(
402409 auth = _Authorization (** auth )
403410 session_state [f"temp:{ auth_id } " ] = auth .access_token
404411
405- session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
412+ if request .session_id :
413+ session_id = request .session_id
414+ else :
415+ session_id = f"temp_session_{ random .randbytes (8 ).hex ()} "
406416 session = await session_service .create_session (
407417 app_name = self ._tmpl_attrs .get ("app_name" ),
408418 user_id = request .user_id ,
@@ -450,7 +460,9 @@ async def _convert_response_events(
450460 """Converts the events to the streaming run response object."""
451461 import collections
452462
453- result = _StreamingRunResponse (events = events , artifacts = [])
463+ result = _StreamingRunResponse (
464+ events = events , artifacts = [], session_id = session_id
465+ )
454466
455467 # Save the generated artifacts into the result object.
456468 artifact_versions = collections .defaultdict (list )
@@ -685,22 +697,35 @@ async def streaming_agent_run_with_events(self, request_json: str):
685697 request = _StreamRunRequest (** json .loads (request_json ))
686698 if not self ._tmpl_attrs .get ("in_memory_runner" ):
687699 self .set_up ()
688- if not self ._tmpl_attrs .get ("artifact_service" ):
689- self .set_up ()
690700 # Prepare the in-memory session.
691701 if not self ._tmpl_attrs .get ("in_memory_artifact_service" ):
692702 self .set_up ()
693703 if not self ._tmpl_attrs .get ("in_memory_session_service" ):
694704 self .set_up ()
695- session = await self ._init_session (
696- session_service = self ._tmpl_attrs .get ("in_memory_session_service" ),
697- artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" ),
698- request = request ,
699- )
705+ session_service = self ._tmpl_attrs .get ("in_memory_session_service" )
706+ artifact_service = self ._tmpl_attrs .get ("in_memory_artifact_service" )
707+ # Try to get the session, if it doesn't exist, create a new one.
708+ session = None
709+ if request .session_id :
710+ try :
711+ session = await session_service .get_session (
712+ app_name = self ._tmpl_attrs .get ("app_name" ),
713+ user_id = request .user_id ,
714+ session_id = request .session_id ,
715+ )
716+ except RuntimeError :
717+ pass
718+ if not session :
719+ # Fall back to create session if the session is not found.
720+ session = await self ._init_session (
721+ session_service = session_service ,
722+ artifact_service = artifact_service ,
723+ request = request ,
724+ )
700725 if not session :
701726 raise RuntimeError ("Session initialization failed." )
702727
703- # Run the agent.
728+ # Run the agent
704729 message_for_agent = types .Content (** request .message )
705730 try :
706731 async for event in self ._tmpl_attrs .get ("in_memory_runner" ).run_async (
@@ -712,15 +737,16 @@ async def streaming_agent_run_with_events(self, request_json: str):
712737 user_id = request .user_id ,
713738 session_id = session .id ,
714739 events = [event ],
715- artifact_service = self . _tmpl_attrs . get ( "in_memory_artifact_service" ) ,
740+ artifact_service = artifact_service ,
716741 )
717742 yield converted_event
718743 finally :
719- await self ._tmpl_attrs .get ("in_memory_session_service" ).delete_session (
720- app_name = self ._tmpl_attrs .get ("app_name" ),
721- user_id = request .user_id ,
722- session_id = session .id ,
723- )
744+ if session and not request .session_id :
745+ await session_service .delete_session (
746+ app_name = self ._tmpl_attrs .get ("app_name" ),
747+ user_id = request .user_id ,
748+ session_id = session .id ,
749+ )
724750
725751 async def async_get_session (
726752 self ,
0 commit comments