Skip to content

Commit 74d7ed6

Browse files
committed
feat: Session persistence
1 parent 75dbbad commit 74d7ed6

19 files changed

+2505
-3
lines changed

src/strands/agent/agent.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
2424
from ..handlers.tool_handler import AgentToolHandler
2525
from ..models.bedrock import BedrockModel
26+
from ..session.session_manager import SessionManager
2627
from ..telemetry.metrics import EventLoopMetrics
2728
from ..telemetry.tracer import get_tracer
2829
from ..tools.registry import ToolRegistry
@@ -192,8 +193,10 @@ def __init__(
192193
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
193194
*,
194195
name: Optional[str] = None,
196+
id: Optional[str] = None,
195197
description: Optional[str] = None,
196198
state: Optional[Union[AgentState, dict]] = None,
199+
session_manager: Optional["SessionManager"] = None,
197200
):
198201
"""Initialize the Agent with the specified configuration.
199202
@@ -228,10 +231,15 @@ def __init__(
228231
trace_attributes: Custom trace attributes to apply to the agent's trace span.
229232
name: name of the Agent
230233
Defaults to None.
234+
id: identifier for the agent, used by session manager.
235+
Defaults to uuid4().
231236
description: description of what the Agent does
232237
Defaults to None.
233238
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
234239
Defaults to an empty AgentState object.
240+
session_manager: Manager for handling agent sessions including conversation history and state.
241+
If provided, enables session-based persistence and state management. The session manager
242+
handles agent_id, and session identification internally.
235243
236244
Raises:
237245
ValueError: If max_parallel_tools is less than 1.
@@ -304,10 +312,18 @@ def __init__(
304312
else:
305313
self.state = AgentState()
306314

315+
# Initialize session management functionality
316+
self.session_manager = session_manager
317+
307318
self.tool_caller = Agent.ToolCaller(self)
308319
self.name = name
320+
self.id = id
309321
self.description = description
310322

323+
# Setup session callback handler if session is enabled
324+
if self.session_manager:
325+
self.session_manager.initialize_agent(self)
326+
311327
@property
312328
def tool(self) -> ToolCaller:
313329
"""Call tool as a function.
@@ -482,6 +498,10 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
482498
new_message: Message = {"role": "user", "content": message_content}
483499
self.messages.append(new_message)
484500

501+
# Save message if session manager is available
502+
if self.session_manager:
503+
self.session_manager.append_message_to_agent_session(self, new_message)
504+
485505
# Execute the event loop cycle with retry logic for context limits
486506
yield from self._execute_event_loop_cycle(kwargs)
487507

@@ -575,6 +595,13 @@ def _record_tool_execution(
575595
messages.append(tool_result_msg)
576596
messages.append(assistant_msg)
577597

598+
# Save to conversation manager if available
599+
if self.session_manager:
600+
self.session_manager.append_message_to_agent_session(self, user_msg)
601+
self.session_manager.append_message_to_agent_session(self, tool_use_msg)
602+
self.session_manager.append_message_to_agent_session(self, tool_result_msg)
603+
self.session_manager.append_message_to_agent_session(self, assistant_msg)
604+
578605
def _start_agent_trace_span(self, prompt: str) -> None:
579606
"""Starts a trace span for the agent.
580607

src/strands/agent/agent_result.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any
7+
from typing import Any, Optional
88

99
from ..telemetry.metrics import EventLoopMetrics
1010
from ..types.content import Message
@@ -20,12 +20,16 @@ class AgentResult:
2020
message: The last message generated by the agent.
2121
metrics: Performance metrics collected during processing.
2222
state: Additional state information from the event loop.
23+
name: (Optional) name of the agent if it is defined
24+
id: (Optional) id of the agent if it is defined
2325
"""
2426

2527
stop_reason: StopReason
2628
message: Message
2729
metrics: EventLoopMetrics
2830
state: Any
31+
name: Optional[str] = None
32+
id: Optional[str] = None
2933

3034
def __str__(self) -> str:
3135
"""Get the agent's last message as a string.

src/strands/agent/state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ class AgentState:
1414
- Get/set/delete operations
1515
"""
1616

17+
<<<<<<< HEAD
1718
def __init__(self, initial_state: Optional[Dict[str, Any]] = None):
19+
=======
20+
def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None):
21+
>>>>>>> 1dbbfe6 (feat: Session persistence)
1822
"""Initialize AgentState."""
1923
self._state: Dict[str, Dict[str, Any]]
2024
if initial_state:

src/strands/event_loop/event_loop.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ def event_loop_cycle(
181181

182182
# Add the response message to the conversation
183183
messages.append(message)
184-
yield {"callback": {"message": message}}
184+
callback_data = {"message": message}
185+
if "agent" in kwargs:
186+
callback_data["agent"] = kwargs["agent"]
187+
yield {"callback": callback_data}
185188

186189
# Update metrics
187190
event_loop_metrics.update_usage(usage)
@@ -390,7 +393,10 @@ def _handle_tool_execution(
390393
}
391394

392395
messages.append(tool_result_message)
393-
yield {"callback": {"message": tool_result_message}}
396+
callback_data = {"message": tool_result_message}
397+
if "agent" in kwargs:
398+
callback_data["agent"] = kwargs["agent"]
399+
yield {"callback": callback_data}
394400

395401
if cycle_span:
396402
tracer = get_tracer()
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""File-based implementation of session manager."""
2+
3+
import logging
4+
from typing import TYPE_CHECKING, Any, Optional
5+
6+
from ..agent.state import AgentState
7+
from ..handlers.callback_handler import CompositeCallbackHandler
8+
from ..types.content import Message
9+
from .exceptions import SessionException
10+
from .file_session_dao import FileSessionDAO
11+
from .session_dao import SessionDAO
12+
from .session_manager import SessionManager
13+
from .session_models import Session, SessionAgent, SessionMessage, SessionType
14+
15+
logger = logging.getLogger(__name__)
16+
17+
if TYPE_CHECKING:
18+
from ..agent.agent import Agent
19+
20+
21+
class AgentSessionManager(SessionManager):
22+
"""Session manager for a single Agent.
23+
24+
This implementation stores sessions as JSON files in a specified directory.
25+
Each session is stored in a separate file named by its session_id.
26+
"""
27+
28+
def __init__(
29+
self,
30+
session_id: str,
31+
session_dao: Optional[SessionDAO] = None,
32+
):
33+
"""Initialize the FileSessionManager."""
34+
self.session_dao = session_dao or FileSessionDAO()
35+
self.session_id = session_id
36+
37+
def append_message_to_agent_session(self, agent: "Agent", message: Message) -> None:
38+
"""Append a message to the agent's session.
39+
40+
Args:
41+
agent: The agent whose session to update
42+
message: The message to append
43+
"""
44+
if agent.id is None:
45+
raise ValueError("`agent.id` must be set before appending message to session.")
46+
47+
session_message = SessionMessage.from_dict(dict(message))
48+
self.session_dao.create_message(self.session_id, agent.id, session_message)
49+
self.session_dao.update_agent(
50+
self.session_id,
51+
SessionAgent(
52+
agent_id=agent.id,
53+
session_id=self.session_id,
54+
event_loop_metrics=agent.event_loop_metrics.to_dict(),
55+
state=agent.state.get(),
56+
),
57+
)
58+
59+
def initialize_agent(self, agent: "Agent") -> None:
60+
"""Restore agent data from the current session.
61+
62+
Args:
63+
agent: Agent instance to restore session data to
64+
65+
Raises:
66+
SessionException: If restore operation fails
67+
"""
68+
if agent.id is None:
69+
raise ValueError("`agent.id` must be set before initializing session.")
70+
71+
try:
72+
# Try to read existing session
73+
session = self.session_dao.read_session(self.session_id)
74+
75+
if session.session_type != SessionType.AGENT:
76+
raise ValueError(f"Invalid session type: {session.session_type}")
77+
78+
if agent.id not in [agent.agent_id for agent in self.session_dao.list_agents(self.session_id)]:
79+
raise ValueError(f"Agent {agent.id} not found in session {self.session_id}")
80+
81+
# Initialize agent
82+
agent.messages = [
83+
session_message.to_message()
84+
for session_message in self.session_dao.list_messages(self.session_id, agent.id)
85+
]
86+
agent.state = AgentState(self.session_dao.read_agent(self.session_id, agent.id).state)
87+
88+
except SessionException:
89+
# Session doesn't exist, create new one
90+
logger.debug("Session not found, creating new session")
91+
# Session doesn't exist, create new one
92+
session = Session(session_id=self.session_id, session_type=SessionType.AGENT)
93+
session_agent = SessionAgent(
94+
agent_id=agent.id,
95+
session_id=self.session_id,
96+
event_loop_metrics=agent.event_loop_metrics.to_dict(),
97+
state=agent.state.get(),
98+
)
99+
self.session_dao.create_session(session)
100+
self.session_dao.create_agent(self.session_id, session_agent)
101+
for message in agent.messages:
102+
session_message = SessionMessage.from_dict(dict(message))
103+
self.session_dao.create_message(self.session_id, agent.id, session_message)
104+
105+
self.session = session
106+
107+
# Attach a callback handler for persisting messages
108+
def session_callback(**kwargs: Any) -> None:
109+
try:
110+
# Handle message persistence
111+
if "message" in kwargs:
112+
message = kwargs["message"]
113+
self.append_message_to_agent_session(kwargs["agent"], message)
114+
except Exception as e:
115+
logger.error("Persistence operation failed", e)
116+
117+
agent.callback_handler = CompositeCallbackHandler(agent.callback_handler, session_callback)

src/strands/session/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Exception classes for session management operations."""
2+
3+
4+
class SessionException(Exception):
5+
"""Exception raised when session operations fail."""

0 commit comments

Comments
 (0)