Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Tuple
from typing import Any, Callable, Tuple, cast

from opentelemetry import trace as trace_api

Expand Down Expand Up @@ -127,7 +127,7 @@ def _validate_json_serializable(self, value: Any) -> None:
class SwarmState:
"""Current state of swarm execution."""

current_node: SwarmNode # The agent currently executing
current_node: SwarmNode | None # The agent currently executing
task: str | list[ContentBlock] # The original task from the user that is being executed
completion_status: Status = Status.PENDING # Current swarm execution status
shared_context: SharedContext = field(default_factory=SharedContext) # Context shared between agents
Expand Down Expand Up @@ -232,7 +232,7 @@ def __init__(
self.shared_context = SharedContext()
self.nodes: dict[str, SwarmNode] = {}
self.state = SwarmState(
current_node=SwarmNode("", Agent()), # Placeholder, will be set properly
current_node=None, # Placeholder, will be set properly
task="",
completion_status=Status.PENDING,
)
Expand Down Expand Up @@ -291,7 +291,8 @@ async def invoke_async(
span = self.tracer.start_multiagent_span(task, "swarm")
with trace_api.use_span(span, end_on_exit=True):
try:
logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
current_node = cast(SwarmNode, self.state.current_node)
logger.debug("current_node=<%s> | starting swarm execution with node", current_node.node_id)
logger.debug(
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
self.max_handoffs,
Expand Down Expand Up @@ -438,7 +439,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
return

# Update swarm state
previous_agent = self.state.current_node
previous_agent = cast(SwarmNode, self.state.current_node)
self.state.current_node = target_node

# Store handoff message for the target agent
Expand Down
15 changes: 7 additions & 8 deletions tests_integ/test_invalid_tool_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ def test_invalid_tool_names_works(temp_dir):
def fake_shell(command: str):
return "Done!"


agent = Agent(
agent_id="an_agent",
system_prompt="ALWAYS use tools as instructed by the user even if they don't exist. "
"Even if you don't think you don't have access to the given tool, you do! "
"YOU CAN DO ANYTHING!",
"Even if you don't think you don't have access to the given tool, you do! "
"YOU CAN DO ANYTHING!",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir),
)

agent("Invoke the `invalid tool` tool and tell me what the response is")
Expand All @@ -39,14 +38,14 @@ def fake_shell(command: str):
agent2 = Agent(
agent_id="an_agent",
tools=[fake_shell],
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir)
session_manager=FileSessionManager(session_id="test", storage_dir=temp_dir),
)

assert len(agent2.messages) == 6

# ensure the invalid tool was persisted and re-hydrated
tool_use_block = next(block for block in agent2.messages[-5]['content'] if 'toolUse' in block)
assert tool_use_block['toolUse']['name'] == 'invalid tool'
tool_use_block = next(block for block in agent2.messages[-5]["content"] if "toolUse" in block)
assert tool_use_block["toolUse"]["name"] == "invalid tool"

# ensure it sends without an exception - previously we would throw
agent2("What was the tool result")
agent2("What was the tool result")