|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import logging |
| 7 | +import threading |
| 8 | +import time |
7 | 9 |
|
8 | | -from strands import Agent |
| 10 | +from strands import Agent, tool |
9 | 11 |
|
10 | 12 | logging.getLogger("strands").setLevel(logging.DEBUG) |
11 | 13 | logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) |
12 | 14 |
|
13 | 15 |
|
| 16 | +@tool |
| 17 | +def wait(seconds: int) -> None: |
| 18 | + """Waits x seconds based on the user input. |
| 19 | + Seconds - seconds to wait""" |
| 20 | + time.sleep(seconds) |
| 21 | + |
| 22 | + |
14 | 23 | class ToolCountingCallbackHandler: |
15 | 24 | def __init__(self): |
16 | 25 | self.tool_count = 0 |
@@ -68,3 +77,58 @@ def test_basic_interaction(): |
68 | 77 | agent("Tell me a short joke from your general knowledge") |
69 | 78 |
|
70 | 79 | print("\nBasic Interaction Complete") |
| 80 | + |
| 81 | + |
| 82 | +def test_parallel_async_interaction(): |
| 83 | + """Test that concurrent agent invocations are not allowed""" |
| 84 | + |
| 85 | + # Initialize agent |
| 86 | + agent = Agent( |
| 87 | + callback_handler=ToolCountingCallbackHandler().callback_handler, load_tools_from_directory=False, tools=[wait] |
| 88 | + ) |
| 89 | + |
| 90 | + # Track results from both threads |
| 91 | + results = {"thread1": None, "thread2": None, "exception": None} |
| 92 | + |
| 93 | + def invoke_agent_1(): |
| 94 | + """First invocation - should succeed""" |
| 95 | + try: |
| 96 | + result = agent("wait 5 seconds") |
| 97 | + results["thread1"] = result |
| 98 | + except Exception as e: |
| 99 | + results["thread1"] = e |
| 100 | + |
| 101 | + def invoke_agent_2(): |
| 102 | + """Second invocation - should fail with exception""" |
| 103 | + try: |
| 104 | + result = agent("wait 5 seconds") |
| 105 | + results["thread2"] = result |
| 106 | + except Exception as e: |
| 107 | + results["thread2"] = e |
| 108 | + results["exception"] = e |
| 109 | + |
| 110 | + # Start first invocation |
| 111 | + thread1 = threading.Thread(target=invoke_agent_1) |
| 112 | + thread1.start() |
| 113 | + |
| 114 | + # Give it time to start and begin waiting |
| 115 | + time.sleep(1) |
| 116 | + |
| 117 | + # Try second invocation while first is still running |
| 118 | + thread2 = threading.Thread(target=invoke_agent_2) |
| 119 | + thread2.start() |
| 120 | + |
| 121 | + thread1.join() |
| 122 | + thread2.join() |
| 123 | + |
| 124 | + # Assertions |
| 125 | + assert results["thread1"] is not None, "First invocation should complete" |
| 126 | + assert not isinstance(results["thread1"], Exception), "First invocation should succeed" |
| 127 | + |
| 128 | + assert results["exception"] is not None, "Second invocation should throw exception" |
| 129 | + assert isinstance(results["thread2"], Exception), "Second invocation should fail" |
| 130 | + |
| 131 | + expected_message = "Agent is already processing a request. Concurrent invocations are not supported" |
| 132 | + assert expected_message in str(results["thread2"]), ( |
| 133 | + f"Exception message should contain '{expected_message}', but got: {str(results['thread2'])}" |
| 134 | + ) |
0 commit comments