Skip to content

Commit 26cb42d

Browse files
committed
feat: Updating session tests
1 parent 0af7fb9 commit 26cb42d

11 files changed

+1110
-44
lines changed

src/strands/session/agent_session_manager.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Agent session manager implementation."""
22

33
import logging
4-
from typing import TYPE_CHECKING
54

65
from ..agent.state import AgentState
76
from ..experimental.hooks.events import AgentInitializedEvent, MessageAddedEvent
@@ -18,9 +17,6 @@
1817

1918
logger = logging.getLogger(__name__)
2019

21-
if TYPE_CHECKING:
22-
pass
23-
2420
DEFAULT_SESSION_AGENT_ID = "default"
2521

2622

@@ -41,9 +37,6 @@ def __init__(
4137
logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id)
4238
session = create_session(session_id=session_id, session_type=SessionType.AGENT)
4339
session_repository.create_session(session)
44-
else:
45-
if session["session_type"] != SessionType.AGENT:
46-
raise ValueError(f"Invalid session type: {session.session_type}")
4740

4841
self.session = session
4942
self._default_agent_initialized = False

src/strands/session/file_session_manager.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,14 @@ def _read_file(self, path: str) -> dict[str, Any]:
5050
try:
5151
with open(path, "r", encoding="utf-8") as f:
5252
return cast(dict[str, Any], json.load(f))
53-
except FileNotFoundError as e:
54-
raise SessionException(f"File not found: {path}") from e
5553
except json.JSONDecodeError as e:
5654
raise SessionException(f"Invalid JSON in file {path}: {e}") from e
5755

5856
def _write_file(self, path: str, data: dict[str, Any]) -> None:
5957
"""Write JSON file."""
60-
try:
61-
os.makedirs(os.path.dirname(path), exist_ok=True)
62-
with open(path, "w", encoding="utf-8") as f:
63-
json.dump(data, f, indent=2, ensure_ascii=False)
64-
except Exception as e:
65-
raise SessionException(f"Failed to write file {path}: {e}") from e
58+
os.makedirs(os.path.dirname(path), exist_ok=True)
59+
with open(path, "w", encoding="utf-8") as f:
60+
json.dump(data, f, indent=2, ensure_ascii=False)
6661

6762
def create_session(self, session: Session) -> Session:
6863
"""Create a new session."""
@@ -132,12 +127,11 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
132127
session_dict = cast(dict, session_message)
133128
self._write_file(message_file, session_dict)
134129

135-
def read_message(self, session_id: str, agent_id: str, message_id: str) -> SessionMessage:
130+
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
136131
"""Read message data."""
137132
message_file = self._get_message_path(session_id, agent_id, message_id)
138133
if not os.path.exists(message_file):
139-
raise SessionException(f"Message {message_id} does not exist for agent {agent_id} in session {session_id}")
140-
134+
return None
141135
message_data = self._read_file(message_file)
142136
return SessionMessage(**message_data) # type: ignore
143137

@@ -154,7 +148,7 @@ def list_messages(
154148
"""List messages for an agent with pagination."""
155149
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
156150
if not os.path.exists(messages_dir):
157-
return []
151+
raise SessionException("messages directory missing from agent: %s in session %s", agent_id, session_id)
158152

159153
# Get all message files and sort by creation time (newest first)
160154
message_files = []
@@ -175,6 +169,8 @@ def list_messages(
175169
# Read message data
176170
messages: list[SessionMessage] = []
177171
for file_path, _ in message_files:
172+
if not os.path.exists(file_path):
173+
continue
178174
message_data = self._read_file(file_path)
179175
messages.append(SessionMessage(**message_data)) # type: ignore
180176

src/strands/session/s3_session_manager.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""S3-based session DAO for cloud storage."""
1+
"""S3-based session manager for cloud storage."""
22

33
import json
44
from typing import Any, Dict, List, Optional, cast
@@ -18,7 +18,7 @@
1818

1919

2020
class S3SessionManager(AgentSessionManager, SessionRepository):
21-
"""S3-based session DAO for cloud storage."""
21+
"""S3-based session manager for cloud storage."""
2222

2323
def __init__(
2424
self,
@@ -29,7 +29,7 @@ def __init__(
2929
boto_client_config: Optional[BotocoreConfig] = None,
3030
region_name: Optional[str] = None,
3131
):
32-
"""Initialize S3SessionDAO with S3 storage.
32+
"""Initialize S3SessionManager with S3 storage.
3333
3434
Args:
3535
session_id: ID for the session
@@ -163,6 +163,10 @@ def read_agent(self, session_id: str, agent_id: str) -> Optional[SessionAgent]:
163163
def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
164164
"""Update agent data in S3."""
165165
agent_id = session_agent["agent_id"]
166+
previous_agent = self.read_agent(session_id=session_id, agent_id=agent_id)
167+
if previous_agent is None:
168+
raise SessionException(f"Agent {agent_id} in session {session_id} does not exist")
169+
session_agent["created_at"] = previous_agent["created_at"]
166170
agent_dict = cast(dict, session_agent)
167171
agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json"
168172
self._write_s3_object(agent_key, agent_dict)
@@ -185,6 +189,12 @@ def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optio
185189
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
186190
"""Update message data in S3."""
187191
message_id = session_message["message_id"]
192+
previous_message = self.read_message(
193+
session_id=session_id, agent_id=agent_id, message_id=session_message["message_id"]
194+
)
195+
if previous_message is None:
196+
raise SessionException(f"Message {message_id} does not exist")
197+
session_message["created_at"] = previous_message["created_at"]
188198
message_dict = cast(dict, session_message)
189199
message_key = self._get_message_path(session_id, agent_id, message_id)
190200
self._write_s3_object(message_key, message_dict)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from strands.session.session_repository import SessionRepository
2+
from strands.types.exceptions import SessionException
3+
4+
5+
class MockedSessionRepository(SessionRepository):
6+
"""Mock repository for testing."""
7+
8+
def __init__(self):
9+
"""Initialize with empty storage."""
10+
self.sessions = {}
11+
self.agents = {}
12+
self.messages = {}
13+
14+
def create_session(self, session):
15+
"""Create a session."""
16+
session_id = session["session_id"]
17+
if session_id in self.sessions:
18+
raise SessionException(f"Session {session_id} already exists")
19+
self.sessions[session_id] = session
20+
self.agents[session_id] = {}
21+
self.messages[session_id] = {}
22+
return session
23+
24+
def read_session(self, session_id):
25+
"""Read a session."""
26+
return self.sessions.get(session_id)
27+
28+
def create_agent(self, session_id, session_agent):
29+
"""Create an agent."""
30+
agent_id = session_agent["agent_id"]
31+
if session_id not in self.sessions:
32+
raise SessionException(f"Session {session_id} does not exist")
33+
if agent_id in self.agents.get(session_id, {}):
34+
raise SessionException(f"Agent {agent_id} already exists in session {session_id}")
35+
self.agents.setdefault(session_id, {})[agent_id] = session_agent
36+
self.messages.setdefault(session_id, {}).setdefault(agent_id, [])
37+
return session_agent
38+
39+
def read_agent(self, session_id, agent_id):
40+
"""Read an agent."""
41+
if session_id not in self.sessions:
42+
return None
43+
return self.agents.get(session_id, {}).get(agent_id)
44+
45+
def update_agent(self, session_id, session_agent):
46+
"""Update an agent."""
47+
agent_id = session_agent["agent_id"]
48+
if session_id not in self.sessions:
49+
raise SessionException(f"Session {session_id} does not exist")
50+
if agent_id not in self.agents.get(session_id, {}):
51+
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
52+
self.agents[session_id][agent_id] = session_agent
53+
54+
def create_message(self, session_id, agent_id, session_message):
55+
"""Create a message."""
56+
if session_id not in self.sessions:
57+
raise SessionException(f"Session {session_id} does not exist")
58+
if agent_id not in self.agents.get(session_id, {}):
59+
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
60+
self.messages.setdefault(session_id, {}).setdefault(agent_id, []).append(session_message)
61+
62+
def read_message(self, session_id, agent_id, message_id):
63+
"""Read a message."""
64+
if session_id not in self.sessions:
65+
return None
66+
if agent_id not in self.agents.get(session_id, {}):
67+
return None
68+
for message in self.messages.get(session_id, {}).get(agent_id, []):
69+
if message["message_id"] == message_id:
70+
return message
71+
return None
72+
73+
def update_message(self, session_id, agent_id, session_message):
74+
"""Update a message."""
75+
message_id = session_message["message_id"]
76+
if session_id not in self.sessions:
77+
raise SessionException(f"Session {session_id} does not exist")
78+
if agent_id not in self.agents.get(session_id, {}):
79+
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
80+
81+
for i, message in enumerate(self.messages.get(session_id, {}).get(agent_id, [])):
82+
if message["message_id"] == message_id:
83+
self.messages[session_id][agent_id][i] = session_message
84+
return
85+
86+
raise SessionException(f"Message {message_id} does not exist")
87+
88+
def list_messages(self, session_id, agent_id, limit=None, offset=0):
89+
"""List messages."""
90+
if session_id not in self.sessions:
91+
return []
92+
if agent_id not in self.agents.get(session_id, {}):
93+
return []
94+
95+
messages = self.messages.get(session_id, {}).get(agent_id, [])
96+
if limit is not None:
97+
return messages[offset : offset + limit]
98+
return messages[offset:]

tests/strands/agent/test_agent.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
from strands.agent import AgentResult
1414
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
1515
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
16+
from strands.agent.state import AgentState
1617
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
1718
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel
19+
from strands.session.agent_session_manager import DEFAULT_SESSION_AGENT_ID, AgentSessionManager
20+
from strands.telemetry.metrics import EventLoopMetrics
1821
from strands.types.content import Messages
1922
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException
23+
from strands.types.session import Session, SessionAgent, SessionType
24+
from tests.fixtures.mock_session_repository import MockedSessionRepository
25+
from tests.fixtures.mocked_model_provider import MockedModelProvider
2026

2127

2228
@pytest.fixture
@@ -1313,6 +1319,11 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr
13131319
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
13141320

13151321

1322+
def test_agent_init_with_state_object():
1323+
agent = Agent(state=AgentState({"foo": "bar"}))
1324+
assert agent.state.get("foo") == "bar"
1325+
1326+
13161327
def test_non_dict_throws_error():
13171328
with pytest.raises(ValueError, match="state must be an AgentState object or a dict"):
13181329
agent = Agent(state={"object", object()})
@@ -1366,3 +1377,38 @@ def test_agent_state_get_breaks_deep_dict_reference():
13661377

13671378
# This will fail if AgentState reflects the updated reference
13681379
json.dumps(agent.state.get())
1380+
1381+
1382+
def test_agent_session_management():
1383+
mock_session_repository = MockedSessionRepository()
1384+
session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository)
1385+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
1386+
agent = Agent(session_manager=session_manager, model=model)
1387+
agent("Hello!")
1388+
1389+
1390+
def test_agent_restored_from_session_management():
1391+
mock_session_repository = MockedSessionRepository()
1392+
mock_session_repository.create_session(
1393+
Session(
1394+
session_id="123",
1395+
session_type=SessionType.AGENT,
1396+
created_at="2025-01-01T00:00:00Z",
1397+
updated_at="2025-01-01T00:00:00Z",
1398+
)
1399+
)
1400+
mock_session_repository.create_agent(
1401+
"123",
1402+
SessionAgent(
1403+
agent_id=DEFAULT_SESSION_AGENT_ID,
1404+
event_loop_metrics=EventLoopMetrics().to_dict(),
1405+
state={"foo": "bar"},
1406+
created_at="2025-01-01T00:00:00Z",
1407+
updated_at="2025-01-01T00:00:00Z",
1408+
),
1409+
)
1410+
session_manager = AgentSessionManager(session_id="123", session_repository=mock_session_repository)
1411+
1412+
agent = Agent(session_manager=session_manager)
1413+
1414+
assert agent.state.get("foo") == "bar"

tests/strands/session/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for session management."""

0 commit comments

Comments
 (0)