Skip to content

Commit 0bc4cd0

Browse files
committed
refactor: add pr feedback
1 parent 8b267d2 commit 0bc4cd0

11 files changed

+228
-101
lines changed

src/strands/agent/agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,9 @@ def __init__(
318318
self.hooks = HookRegistry()
319319

320320
# Initialize session management functionality
321-
self.session_manager = session_manager
322-
if self.session_manager:
323-
self.hooks.add_hook(self.session_manager)
321+
self._session_manager = session_manager
322+
if self._session_manager:
323+
self.hooks.add_hook(self._session_manager)
324324

325325
if hooks:
326326
for hook in hooks:

src/strands/session/file_session_manager.py

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..types.exceptions import SessionException
1212
from ..types.session import Session, SessionAgent, SessionMessage
13-
from .agent_session_manager import AgentSessionManager
13+
from .repository_session_manager import RepositorySessionManager
1414
from .session_repository import SessionRepository
1515

1616
logger = logging.getLogger(__name__)
@@ -20,8 +20,21 @@
2020
MESSAGE_PREFIX = "message_"
2121

2222

23-
class FileSessionManager(AgentSessionManager, SessionRepository):
24-
"""File-based session manager for local filesystem storage."""
23+
class FileSessionManager(RepositorySessionManager, SessionRepository):
24+
"""File-based session manager for local filesystem storage.
25+
26+
Creates the following filesystem structure for the session storage:
27+
/<sessions_dir>/
28+
└── session_<session_id>/
29+
├── session.json # Session metadata
30+
└── agents/
31+
└── agent_<agent_id>/
32+
├── agent.json # Agent metadata
33+
└── messages/
34+
├── message_<created_timestamp>_<id1>.json
35+
└── message_<created_timestamp>_<id2>.json
36+
37+
"""
2538

2639
def __init__(self, session_id: str, storage_dir: Optional[str] = None):
2740
"""Initialize FileSession with filesystem storage.
@@ -44,10 +57,22 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
4457
session_path = self._get_session_path(session_id)
4558
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
4659

47-
def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str:
48-
"""Get message file path."""
60+
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
61+
"""Get message file path.
62+
63+
Args:
64+
session_id: ID of the session
65+
agent_id: ID of the agent
66+
message_id: ID of the message
67+
timestamp: ISO format timestamp to include in filename for sorting
68+
Returns:
69+
The filename for the message
70+
"""
4971
agent_path = self._get_agent_path(session_id, agent_id)
50-
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
72+
# Use timestamp for sortable filenames
73+
# Replace colons and periods in ISO format with underscores for filesystem compatibility
74+
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
75+
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json")
5176

5277
def _read_file(self, path: str) -> dict[str, Any]:
5378
"""Read JSON file."""
@@ -135,17 +160,26 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
135160
session_id,
136161
agent_id,
137162
session_message.message_id,
163+
session_message.created_at,
138164
)
139165
session_dict = asdict(session_message)
140166
self._write_file(message_file, session_dict)
141167

142168
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
143169
"""Read message data."""
144-
message_file = self._get_message_path(session_id, agent_id, message_id)
145-
if not os.path.exists(message_file):
170+
# Get the messages directory
171+
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
172+
if not os.path.exists(messages_dir):
146173
return None
147-
message_data = self._read_file(message_file)
148-
return SessionMessage.from_dict(message_data)
174+
175+
# List files in messages directory, and check if the filename ends with the message id
176+
for filename in os.listdir(messages_dir):
177+
if filename.endswith(f"{message_id}.json"):
178+
file_path = os.path.join(messages_dir, filename)
179+
message_data = self._read_file(file_path)
180+
return SessionMessage.from_dict(message_data)
181+
182+
return None
149183

150184
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
151185
"""Update message data."""
@@ -156,7 +190,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
156190

157191
# Preserve the original created_at timestamp
158192
session_message.created_at = previous_message.created_at
159-
message_file = self._get_message_path(session_id, agent_id, message_id)
193+
message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
160194
self._write_file(message_file, asdict(session_message))
161195

162196
def list_messages(
@@ -168,20 +202,25 @@ def list_messages(
168202
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")
169203

170204
# Read all message files
171-
messages: list[SessionMessage] = []
205+
message_files: list[str] = []
172206
for filename in os.listdir(messages_dir):
173207
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
174-
file_path = os.path.join(messages_dir, filename)
175-
message_data = self._read_file(file_path)
176-
messages.append(SessionMessage.from_dict(message_data))
208+
message_files.append(filename)
177209

178-
# Sort by created_at timestamp (oldest first)
179-
messages.sort(key=lambda x: x.created_at)
210+
# Sort filenames - the timestamp in the file's name will sort chronologically
211+
message_files.sort()
180212

181-
# Apply pagination
213+
# Apply pagination to filenames
182214
if limit is not None:
183-
messages = messages[offset : offset + limit]
215+
message_files = message_files[offset : offset + limit]
184216
else:
185-
messages = messages[offset:]
217+
message_files = message_files[offset:]
218+
219+
# Load only the message files
220+
messages: list[SessionMessage] = []
221+
for filename in message_files:
222+
file_path = os.path.join(messages_dir, filename)
223+
message_data = self._read_file(file_path)
224+
messages.append(SessionMessage.from_dict(message_data))
186225

187226
return messages

src/strands/session/agent_session_manager.py renamed to src/strands/session/repository_session_manager.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
"""Agent session manager implementation."""
1+
"""Repository session manager implementation."""
22

33
import logging
44

5-
from ..agent.agent import _DEFAULT_AGENT_ID, Agent
5+
from ..agent.agent import Agent
66
from ..agent.state import AgentState
77
from ..types.content import Message
88
from ..types.exceptions import SessionException
@@ -17,29 +17,38 @@
1717

1818
logger = logging.getLogger(__name__)
1919

20-
DEFAULT_SESSION_AGENT_ID = "default"
2120

22-
23-
class AgentSessionManager(SessionManager):
24-
"""Session manager for persisting agent's in a Session."""
21+
class RepositorySessionManager(SessionManager):
22+
"""Session manager for persisting agents in a SessionRepository."""
2523

2624
def __init__(
2725
self,
2826
session_id: str,
2927
session_repository: SessionRepository,
3028
):
31-
"""Initialize the AgentSessionManager."""
29+
"""Initialize the RepositorySessionManager.
30+
31+
If no session with the specified session_id exists yet, it will be created
32+
in the session_repository.
33+
34+
Args:
35+
session_id: ID to use for the session. A new session with this id will be created if it does
36+
not exist in the reposiory yet
37+
session_repository: Underlying session repository to use to store the sessions state.
38+
"""
3239
self.session_repository = session_repository
3340
self.session_id = session_id
3441
session = session_repository.read_session(session_id)
3542
# Create a session if it does not exist yet
3643
if session is None:
37-
logger.debug("session_id=<%s> | Session not found, creating new session.", self.session_id)
44+
logger.debug("session_id=<%s> | session not found, creating new session", self.session_id)
3845
session = Session(session_id=session_id, session_type=SessionType.AGENT)
3946
session_repository.create_session(session)
4047

4148
self.session = session
42-
self._default_agent_initialized = False
49+
50+
# Keep track of the initialized agent id's so that two agents in a session cannot share an id
51+
self._initialized_agent_ids: set[str] = set()
4352

4453
def append_message(self, message: Message, agent: Agent) -> None:
4554
"""Append a message to the agent's session.
@@ -49,12 +58,10 @@ def append_message(self, message: Message, agent: Agent) -> None:
4958
agent: Agent to append the message to
5059
"""
5160
session_message = SessionMessage.from_message(message)
52-
if agent.agent_id is None:
53-
raise ValueError("`agent.agent_id` must be set before appending message to session.")
5461
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
5562

5663
def sync_agent(self, agent: Agent) -> None:
57-
"""Sync agent to the session.
64+
"""Serialize and update the agent into the session repository.
5865
5966
Args:
6067
agent: Agent to sync to the session.
@@ -70,16 +77,15 @@ def initialize(self, agent: Agent) -> None:
7077
Args:
7178
agent: Agent to initialize from the session
7279
"""
73-
if agent.agent_id is _DEFAULT_AGENT_ID:
74-
if self._default_agent_initialized:
75-
raise SessionException("Set `agent_id` to support more than one agent in a session.")
76-
self._default_agent_initialized = True
80+
if agent.agent_id in self._initialized_agent_ids:
81+
raise SessionException("The `agent_id` of an agent must be unique in a session.")
82+
self._initialized_agent_ids.add(agent.agent_id)
7783

7884
session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)
7985

8086
if session_agent is None:
8187
logger.debug(
82-
"agent_id=<%s> | session_id=<%s> | Creating agent.",
88+
"agent_id=<%s> | session_id=<%s> | creating agent",
8389
agent.agent_id,
8490
self.session_id,
8591
)
@@ -91,7 +97,7 @@ def initialize(self, agent: Agent) -> None:
9197
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
9298
else:
9399
logger.debug(
94-
"agent_id=<%s> | session_id=<%s> | Restoring agent.",
100+
"agent_id=<%s> | session_id=<%s> | restoring agent",
95101
agent.agent_id,
96102
self.session_id,
97103
)

src/strands/session/s3_session_manager.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ..types.exceptions import SessionException
1313
from ..types.session import Session, SessionAgent, SessionMessage
14-
from .agent_session_manager import AgentSessionManager
14+
from .repository_session_manager import RepositorySessionManager
1515
from .session_repository import SessionRepository
1616

1717
logger = logging.getLogger(__name__)
@@ -21,8 +21,21 @@
2121
MESSAGE_PREFIX = "message_"
2222

2323

24-
class S3SessionManager(AgentSessionManager, SessionRepository):
25-
"""S3-based session manager for cloud storage."""
24+
class S3SessionManager(RepositorySessionManager, SessionRepository):
25+
"""S3-based session manager for cloud storage.
26+
27+
Creates the following filesystem structure for the session storage:
28+
/<sessions_dir>/
29+
└── session_<session_id>/
30+
├── session.json # Session metadata
31+
└── agents/
32+
└── agent_<agent_id>/
33+
├── agent.json # Agent metadata
34+
└── messages/
35+
├── message_<created_timestamp>_<id1>.json
36+
└── message_<created_timestamp>_<id2>.json
37+
38+
"""
2639

2740
def __init__(
2841
self,
@@ -72,10 +85,22 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
7285
session_path = self._get_session_path(session_id)
7386
return f"{session_path}agents/{AGENT_PREFIX}{agent_id}/"
7487

75-
def _get_message_path(self, session_id: str, agent_id: str, message_id: str) -> str:
76-
"""Get message S3 key."""
88+
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
89+
"""Get message S3 key.
90+
91+
Args:
92+
session_id: ID of the session
93+
agent_id: ID of the agent
94+
message_id: ID of the message
95+
timestamp: ISO format timestamp to include in key for sorting
96+
Returns:
97+
The key for the message
98+
"""
7799
agent_path = self._get_agent_path(session_id, agent_id)
78-
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"
100+
# Use timestamp for sortable keys
101+
# Replace colons and periods in ISO format with underscores for filesystem compatibility
102+
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
103+
return f"{agent_path}messages/{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json"
79104

80105
def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]:
81106
"""Read JSON object from S3."""
@@ -180,16 +205,29 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio
180205
"""Create a new message in S3."""
181206
message_id = session_message.message_id
182207
message_dict = asdict(session_message)
183-
message_key = self._get_message_path(session_id, agent_id, message_id)
208+
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
184209
self._write_s3_object(message_key, message_dict)
185210

186211
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
187212
"""Read message data from S3."""
188-
message_key = self._get_message_path(session_id, agent_id, message_id)
189-
message_data = self._read_s3_object(message_key)
190-
if message_data is None:
213+
# Get the messages prefix
214+
messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/"
215+
try:
216+
paginator = self.client.get_paginator("list_objects_v2")
217+
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)
218+
219+
for page in pages:
220+
if "Contents" in page:
221+
for obj in page["Contents"]:
222+
if obj["Key"].endswith(f"{message_id}.json"):
223+
message_data = self._read_s3_object(obj["Key"])
224+
if message_data:
225+
return SessionMessage.from_dict(message_data)
226+
191227
return None
192-
return SessionMessage.from_dict(message_data)
228+
229+
except ClientError as e:
230+
raise SessionException(f"S3 error reading message: {e}") from e
193231

194232
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
195233
"""Update message data in S3."""
@@ -200,7 +238,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
200238

201239
# Preserve creation timestamp
202240
session_message.created_at = previous_message.created_at
203-
message_key = self._get_message_path(session_id, agent_id, message_id)
241+
message_key = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
204242
self._write_s3_object(message_key, asdict(session_message))
205243

206244
def list_messages(
@@ -212,24 +250,29 @@ def list_messages(
212250
paginator = self.client.get_paginator("list_objects_v2")
213251
pages = paginator.paginate(Bucket=self.bucket, Prefix=messages_prefix)
214252

215-
# Read all message objects
216-
messages: List[SessionMessage] = []
253+
# Collect all message keys first
254+
message_keys = []
217255
for page in pages:
218256
if "Contents" in page:
219257
for obj in page["Contents"]:
220-
if obj["Key"].endswith(".json"):
221-
message_data = self._read_s3_object(obj["Key"])
222-
if message_data:
223-
messages.append(SessionMessage.from_dict(message_data))
258+
if obj["Key"].endswith(".json") and MESSAGE_PREFIX in obj["Key"]:
259+
message_keys.append(obj["Key"])
224260

225-
# Sort by created_at timestamp (oldest first)
226-
messages.sort(key=lambda x: x.created_at)
261+
# Sort keys - timestamp prefixed keys will sort chronologically
262+
message_keys.sort()
227263

228-
# Apply pagination
264+
# Apply pagination to keys before loading content
229265
if limit is not None:
230-
messages = messages[offset : offset + limit]
266+
message_keys = message_keys[offset : offset + limit]
231267
else:
232-
messages = messages[offset:]
268+
message_keys = message_keys[offset:]
269+
270+
# Load only the required message objects
271+
messages: List[SessionMessage] = []
272+
for key in message_keys:
273+
message_data = self._read_s3_object(key)
274+
if message_data:
275+
messages.append(SessionMessage.from_dict(message_data))
233276

234277
return messages
235278

0 commit comments

Comments
 (0)