Skip to content

Commit 9602938

Browse files
committed
fix(concurrent-invocations): added protection from concurrent invocations to the same agent instance
1 parent b4efc9d commit 9602938

File tree

3 files changed

+87
-7
lines changed

3 files changed

+87
-7
lines changed

src/strands/agent/agent.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
1010
"""
1111

12+
import asyncio
1213
import json
1314
import logging
1415
import random
@@ -293,6 +294,7 @@ def __init__(
293294
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
294295
self.name = name or _DEFAULT_AGENT_NAME
295296
self.description = description
297+
self._invocation_lock = asyncio.Lock()
296298

297299
# If not provided, create a new PrintingCallbackHandler instance
298300
# If explicitly set to None, use null_callback_handler
@@ -494,13 +496,17 @@ async def invoke_async(
494496
- metrics: Performance metrics from the event loop
495497
- state: The final state of the event loop
496498
"""
497-
events = self.stream_async(
498-
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
499-
)
500-
async for event in events:
501-
_ = event
499+
if self._invocation_lock.locked():
500+
raise RuntimeError("Agent is already processing a request. Concurrent invocations are not supported.")
501+
502+
async with self._invocation_lock:
503+
events = self.stream_async(
504+
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
505+
)
506+
async for event in events:
507+
_ = event
502508

503-
return cast(AgentResult, event["result"])
509+
return cast(AgentResult, event["result"])
504510

505511
def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
506512
"""This method allows you to get structured output from the agent.

tests/strands/agent/test_agent.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,16 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator):
783783
assert tru_message == exp_message
784784

785785

786+
@pytest.mark.asyncio
787+
async def test_agent_parallel_invocations():
788+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}])
789+
agent = Agent(model=model)
790+
791+
async with agent._invocation_lock:
792+
with pytest.raises(RuntimeError, match="Concurrent invocations are not supported"):
793+
await agent.invoke_async("test")
794+
795+
786796
@pytest.mark.asyncio
787797
async def test_agent_invoke_async(mock_model, agent, agenerator):
788798
mock_model.mock_stream.return_value = agenerator(

tests_integ/test_stream_agent.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44
"""
55

66
import logging
7+
import threading
8+
import time
79

8-
from strands import Agent
10+
from strands import Agent, tool
911

1012
logging.getLogger("strands").setLevel(logging.DEBUG)
1113
logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()])
1214

1315

16+
@tool
17+
def wait(seconds: int) -> None:
18+
"""Waits x seconds based on the user input.
19+
Seconds - seconds to wait"""
20+
time.sleep(seconds)
21+
22+
1423
class ToolCountingCallbackHandler:
1524
def __init__(self):
1625
self.tool_count = 0
@@ -68,3 +77,58 @@ def test_basic_interaction():
6877
agent("Tell me a short joke from your general knowledge")
6978

7079
print("\nBasic Interaction Complete")
80+
81+
82+
def test_parallel_async_interaction():
83+
"""Test that concurrent agent invocations are not allowed"""
84+
85+
# Initialize agent
86+
agent = Agent(
87+
callback_handler=ToolCountingCallbackHandler().callback_handler, load_tools_from_directory=False, tools=[wait]
88+
)
89+
90+
# Track results from both threads
91+
results = {"thread1": None, "thread2": None, "exception": None}
92+
93+
def invoke_agent_1():
94+
"""First invocation - should succeed"""
95+
try:
96+
result = agent("wait 5 seconds")
97+
results["thread1"] = result
98+
except Exception as e:
99+
results["thread1"] = e
100+
101+
def invoke_agent_2():
102+
"""Second invocation - should fail with exception"""
103+
try:
104+
result = agent("wait 5 seconds")
105+
results["thread2"] = result
106+
except Exception as e:
107+
results["thread2"] = e
108+
results["exception"] = e
109+
110+
# Start first invocation
111+
thread1 = threading.Thread(target=invoke_agent_1)
112+
thread1.start()
113+
114+
# Give it time to start and begin waiting
115+
time.sleep(1)
116+
117+
# Try second invocation while first is still running
118+
thread2 = threading.Thread(target=invoke_agent_2)
119+
thread2.start()
120+
121+
thread1.join()
122+
thread2.join()
123+
124+
# Assertions
125+
assert results["thread1"] is not None, "First invocation should complete"
126+
assert not isinstance(results["thread1"], Exception), "First invocation should succeed"
127+
128+
assert results["exception"] is not None, "Second invocation should throw exception"
129+
assert isinstance(results["thread2"], Exception), "Second invocation should fail"
130+
131+
expected_message = "Agent is already processing a request. Concurrent invocations are not supported"
132+
assert expected_message in str(results["thread2"]), (
133+
f"Exception message should contain '{expected_message}', but got: {str(results['thread2'])}"
134+
)

0 commit comments

Comments
 (0)