Skip to content

Commit 053f483

Browse files
authored
Merge branch 'main' into multiagent-streaming
2 parents 82fdd4f + 95906fa commit 053f483

File tree

14 files changed

+554
-57
lines changed

14 files changed

+554
-57
lines changed

src/strands/models/openai.py

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,16 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str
214214
for message in messages:
215215
contents = message["content"]
216216

217+
# Check for reasoningContent and warn user
218+
if any("reasoningContent" in content for content in contents):
219+
logger.warning(
220+
"reasoningContent is not supported in multi-turn conversations with the Chat Completions API."
221+
)
222+
217223
formatted_contents = [
218224
cls.format_request_message_content(content)
219225
for content in contents
220-
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
226+
if not any(block_type in content for block_type in ["toolResult", "toolUse", "reasoningContent"])
221227
]
222228
formatted_tool_calls = [
223229
cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
@@ -405,38 +411,46 @@ async def stream(
405411

406412
logger.debug("got response from model")
407413
yield self.format_chunk({"chunk_type": "message_start"})
408-
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})
409-
410414
tool_calls: dict[int, list[Any]] = {}
415+
data_type = None
416+
finish_reason = None # Store finish_reason for later use
417+
event = None # Initialize for scope safety
411418

412419
async for event in response:
413420
# Defensive: skip events with empty or missing choices
414421
if not getattr(event, "choices", None):
415422
continue
416423
choice = event.choices[0]
417424

418-
if choice.delta.content:
419-
yield self.format_chunk(
420-
{"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
421-
)
422-
423425
if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
426+
chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
427+
for chunk in chunks:
428+
yield chunk
424429
yield self.format_chunk(
425430
{
426431
"chunk_type": "content_delta",
427-
"data_type": "reasoning_content",
432+
"data_type": data_type,
428433
"data": choice.delta.reasoning_content,
429434
}
430435
)
431436

437+
if choice.delta.content:
438+
chunks, data_type = self._stream_switch_content("text", data_type)
439+
for chunk in chunks:
440+
yield chunk
441+
yield self.format_chunk(
442+
{"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
443+
)
444+
432445
for tool_call in choice.delta.tool_calls or []:
433446
tool_calls.setdefault(tool_call.index, []).append(tool_call)
434447

435448
if choice.finish_reason:
449+
finish_reason = choice.finish_reason # Store for use outside loop
450+
if data_type:
451+
yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
436452
break
437453

438-
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
439-
440454
for tool_deltas in tool_calls.values():
441455
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
442456

@@ -445,17 +459,37 @@ async def stream(
445459

446460
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
447461

448-
yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
462+
yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason or "end_turn"})
449463

450464
# Skip remaining events as we don't have use for anything except the final usage payload
451465
async for event in response:
452466
_ = event
453467

454-
if event.usage:
468+
if event and hasattr(event, "usage") and event.usage:
455469
yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
456470

457471
logger.debug("finished streaming response from model")
458472

473+
def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> tuple[list[StreamEvent], str]:
474+
"""Handle switching to a new content stream.
475+
476+
Args:
477+
data_type: The next content data type.
478+
prev_data_type: The previous content data type.
479+
480+
Returns:
481+
Tuple containing:
482+
- Stop block for previous content and the start block for the next content.
483+
- Next content data type.
484+
"""
485+
chunks = []
486+
if data_type != prev_data_type:
487+
if prev_data_type is not None:
488+
chunks.append(self.format_chunk({"chunk_type": "content_stop", "data_type": prev_data_type}))
489+
chunks.append(self.format_chunk({"chunk_type": "content_start", "data_type": data_type}))
490+
491+
return chunks, data_type
492+
459493
@override
460494
async def structured_output(
461495
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any

src/strands/multiagent/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult":
137137
metrics = _parse_metrics(data.get("accumulated_metrics", {}))
138138

139139
multiagent_result = cls(
140-
status=Status(data.get("status", Status.PENDING.value)),
140+
status=Status(data["status"]),
141141
results=results,
142142
accumulated_usage=usage,
143143
accumulated_metrics=metrics,
@@ -164,8 +164,13 @@ class MultiAgentBase(ABC):
164164
165165
This class integrates with existing Strands Agent instances and provides
166166
multi-agent orchestration capabilities.
167+
168+
Attributes:
169+
id: Unique MultiAgent id for session management,etc.
167170
"""
168171

172+
id: str
173+
169174
@abstractmethod
170175
async def invoke_async(
171176
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any

src/strands/multiagent/swarm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _validate_json_serializable(self, value: Any) -> None:
133133
class SwarmState:
134134
"""Current state of swarm execution."""
135135

136-
current_node: SwarmNode # The agent currently executing
136+
current_node: SwarmNode | None # The agent currently executing
137137
task: str | list[ContentBlock] # The original task from the user that is being executed
138138
completion_status: Status = Status.PENDING # Current swarm execution status
139139
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
@@ -238,7 +238,7 @@ def __init__(
238238
self.shared_context = SharedContext()
239239
self.nodes: dict[str, SwarmNode] = {}
240240
self.state = SwarmState(
241-
current_node=SwarmNode("", Agent()), # Placeholder, will be set properly
241+
current_node=None, # Placeholder, will be set properly
242242
task="",
243243
completion_status=Status.PENDING,
244244
)
@@ -328,7 +328,8 @@ async def stream_async(
328328
span = self.tracer.start_multiagent_span(task, "swarm")
329329
with trace_api.use_span(span, end_on_exit=True):
330330
try:
331-
logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
331+
current_node = cast(SwarmNode, self.state.current_node)
332+
logger.debug("current_node=<%s> | starting swarm execution with node", current_node.node_id)
332333
logger.debug(
333334
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
334335
self.max_handoffs,
@@ -522,7 +523,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
522523
return
523524

524525
# Update swarm state
525-
previous_agent = self.state.current_node
526+
previous_agent = cast(SwarmNode, self.state.current_node)
526527
self.state.current_node = target_node
527528

528529
# Store handoff message for the target agent

src/strands/session/file_session_manager.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
import os
66
import shutil
77
import tempfile
8-
from typing import Any, Optional, cast
8+
from typing import TYPE_CHECKING, Any, Optional, cast
99

1010
from .. import _identifier
1111
from ..types.exceptions import SessionException
1212
from ..types.session import Session, SessionAgent, SessionMessage
1313
from .repository_session_manager import RepositorySessionManager
1414
from .session_repository import SessionRepository
1515

16+
if TYPE_CHECKING:
17+
from ..multiagent.base import MultiAgentBase
18+
1619
logger = logging.getLogger(__name__)
1720

1821
SESSION_PREFIX = "session_"
1922
AGENT_PREFIX = "agent_"
2023
MESSAGE_PREFIX = "message_"
24+
MULTI_AGENT_PREFIX = "multi_agent_"
2125

2226

2327
class FileSessionManager(RepositorySessionManager, SessionRepository):
@@ -37,7 +41,12 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
3741
```
3842
"""
3943

40-
def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any):
44+
def __init__(
45+
self,
46+
session_id: str,
47+
storage_dir: Optional[str] = None,
48+
**kwargs: Any,
49+
):
4150
"""Initialize FileSession with filesystem storage.
4251
4352
Args:
@@ -107,8 +116,11 @@ def _read_file(self, path: str) -> dict[str, Any]:
107116
def _write_file(self, path: str, data: dict[str, Any]) -> None:
108117
"""Write JSON file."""
109118
os.makedirs(os.path.dirname(path), exist_ok=True)
110-
with open(path, "w", encoding="utf-8") as f:
119+
# This automic write ensure the completeness of session files in both single agent/ multi agents
120+
tmp = f"{path}.tmp"
121+
with open(tmp, "w", encoding="utf-8", newline="\n") as f:
111122
json.dump(data, f, indent=2, ensure_ascii=False)
123+
os.replace(tmp, path)
112124

113125
def create_session(self, session: Session, **kwargs: Any) -> Session:
114126
"""Create a new session."""
@@ -119,6 +131,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:
119131
# Create directory structure
120132
os.makedirs(session_dir, exist_ok=True)
121133
os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True)
134+
os.makedirs(os.path.join(session_dir, "multi_agents"), exist_ok=True)
122135

123136
# Write session file
124137
session_file = os.path.join(session_dir, "session.json")
@@ -239,3 +252,36 @@ def list_messages(
239252
messages.append(SessionMessage.from_dict(message_data))
240253

241254
return messages
255+
256+
def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str:
257+
"""Get multi-agent state file path."""
258+
session_path = self._get_session_path(session_id)
259+
multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT)
260+
return os.path.join(session_path, "multi_agents", f"{MULTI_AGENT_PREFIX}{multi_agent_id}")
261+
262+
def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
263+
"""Create a new multiagent state in the session."""
264+
multi_agent_id = multi_agent.id
265+
multi_agent_dir = self._get_multi_agent_path(session_id, multi_agent_id)
266+
os.makedirs(multi_agent_dir, exist_ok=True)
267+
268+
multi_agent_file = os.path.join(multi_agent_dir, "multi_agent.json")
269+
session_data = multi_agent.serialize_state()
270+
self._write_file(multi_agent_file, session_data)
271+
272+
def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]:
273+
"""Read multi-agent state from filesystem."""
274+
multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json")
275+
if not os.path.exists(multi_agent_file):
276+
return None
277+
return self._read_file(multi_agent_file)
278+
279+
def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
280+
"""Update multi-agent state from filesystem."""
281+
multi_agent_state = multi_agent.serialize_state()
282+
previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id)
283+
if previous_multi_agent_state is None:
284+
raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist")
285+
286+
multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent.id), "multi_agent.json")
287+
self._write_file(multi_agent_file, multi_agent_state)

src/strands/session/repository_session_manager.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,20 @@
1717

1818
if TYPE_CHECKING:
1919
from ..agent.agent import Agent
20+
from ..multiagent.base import MultiAgentBase
2021

2122
logger = logging.getLogger(__name__)
2223

2324

2425
class RepositorySessionManager(SessionManager):
2526
"""Session manager for persisting agents in a SessionRepository."""
2627

27-
def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any):
28+
def __init__(
29+
self,
30+
session_id: str,
31+
session_repository: SessionRepository,
32+
**kwargs: Any,
33+
):
2834
"""Initialize the RepositorySessionManager.
2935
3036
If no session with the specified session_id exists yet, it will be created
@@ -152,3 +158,26 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None:
152158

153159
# Restore the agents messages array including the optional prepend messages
154160
agent.messages = prepend_messages + [session_message.to_message() for session_message in session_messages]
161+
162+
def sync_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
163+
"""Serialize and update the multi-agent state into the session repository.
164+
165+
Args:
166+
source: Multi-agent source object to sync to the session.
167+
**kwargs: Additional keyword arguments for future extensibility.
168+
"""
169+
self.session_repository.update_multi_agent(self.session_id, source)
170+
171+
def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> None:
172+
"""Initialize multi-agent state from the session repository.
173+
174+
Args:
175+
source: Multi-agent source object to restore state into
176+
**kwargs: Additional keyword arguments for future extensibility.
177+
"""
178+
state = self.session_repository.read_multi_agent(self.session_id, source.id, **kwargs)
179+
if state is None:
180+
self.session_repository.create_multi_agent(self.session_id, source, **kwargs)
181+
else:
182+
logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id)
183+
source.deserialize_state(state)

src/strands/session/s3_session_manager.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import Any, Dict, List, Optional, cast
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
66

77
import boto3
88
from botocore.config import Config as BotocoreConfig
@@ -14,11 +14,15 @@
1414
from .repository_session_manager import RepositorySessionManager
1515
from .session_repository import SessionRepository
1616

17+
if TYPE_CHECKING:
18+
from ..multiagent.base import MultiAgentBase
19+
1720
logger = logging.getLogger(__name__)
1821

1922
SESSION_PREFIX = "session_"
2023
AGENT_PREFIX = "agent_"
2124
MESSAGE_PREFIX = "message_"
25+
MULTI_AGENT_PREFIX = "multi_agent_"
2226

2327

2428
class S3SessionManager(RepositorySessionManager, SessionRepository):
@@ -294,3 +298,31 @@ def list_messages(
294298

295299
except ClientError as e:
296300
raise SessionException(f"S3 error reading messages: {e}") from e
301+
302+
def _get_multi_agent_path(self, session_id: str, multi_agent_id: str) -> str:
303+
"""Get multi-agent S3 prefix."""
304+
session_path = self._get_session_path(session_id)
305+
multi_agent_id = _identifier.validate(multi_agent_id, _identifier.Identifier.AGENT)
306+
return f"{session_path}multi_agents/{MULTI_AGENT_PREFIX}{multi_agent_id}/"
307+
308+
def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
309+
"""Create a new multiagent state in S3."""
310+
multi_agent_id = multi_agent.id
311+
multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json"
312+
session_data = multi_agent.serialize_state()
313+
self._write_s3_object(multi_agent_key, session_data)
314+
315+
def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]:
316+
"""Read multi-agent state from S3."""
317+
multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json"
318+
return self._read_s3_object(multi_agent_key)
319+
320+
def update_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **kwargs: Any) -> None:
321+
"""Update multi-agent state in S3."""
322+
multi_agent_state = multi_agent.serialize_state()
323+
previous_multi_agent_state = self.read_multi_agent(session_id=session_id, multi_agent_id=multi_agent.id)
324+
if previous_multi_agent_state is None:
325+
raise SessionException(f"MultiAgent state {multi_agent.id} in session {session_id} does not exist")
326+
327+
multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent.id)}multi_agent.json"
328+
self._write_s3_object(multi_agent_key, multi_agent_state)

0 commit comments

Comments
 (0)