|
15 | 15 | Run: pytest tests/test_agent.py -v |
16 | 16 | """ |
17 | 17 |
|
| 18 | +import asyncio |
| 19 | + |
18 | 20 | import pytest |
| 21 | +import pytest_asyncio |
19 | 22 |
|
20 | | -from agentex.lib.testing import async_test_agent, assert_valid_agent_response |
| 23 | +from agentex.lib.testing import async_test_agent, stream_agent_response, assert_valid_agent_response |
| 24 | +from agentex.lib.testing.sessions import AsyncAgentTest |
| 25 | +from agentex.types.agent_rpc_result import StreamTaskMessageDone, StreamTaskMessageFull |
21 | 26 |
|
22 | 27 | AGENT_NAME = "at010-agent-chat" |
23 | 28 |
|
24 | 29 |
|
25 | | -@pytest.mark.asyncio |
26 | | -async def test_agent_basic(): |
27 | | - """Test basic agent functionality.""" |
28 | | - async with async_test_agent(agent_name=AGENT_NAME) as test: |
29 | | - response = await test.send_event("Test message", timeout_seconds=60.0) |
30 | | - assert_valid_agent_response(response) |
| 30 | +@pytest.fixture |
| 31 | +def agent_name(): |
| 32 | + """Return the agent name for testing.""" |
| 33 | + return AGENT_NAME |
31 | 34 |
|
32 | 35 |
|
33 | | -@pytest.mark.asyncio |
34 | | -async def test_agent_streaming(): |
35 | | - """Test streaming responses.""" |
36 | | - async with async_test_agent(agent_name=AGENT_NAME) as test: |
37 | | - events = [] |
38 | | - async for event in test.send_event_and_stream("Stream test", timeout_seconds=60.0): |
39 | | - events.append(event) |
40 | | - if event.get("type") == "done": |
41 | | - break |
42 | | - assert len(events) > 0 |
| 36 | +@pytest_asyncio.fixture |
| 37 | +async def test_agent(agent_name: str): |
| 38 | + """Fixture to create a test async agent.""" |
| 39 | + async with async_test_agent(agent_name=agent_name) as test: |
| 40 | + yield test |
| 41 | + |
| 42 | +class TestNonStreamingEvents: |
| 43 | + """Test non-streaming event sending and polling with OpenAI Agents SDK.""" |
43 | 44 |
|
44 | 45 | @pytest.mark.asyncio |
45 | | - async def test_send_event_and_poll_with_calculator(self, client: AsyncAgentex, agent_id: str): |
46 | | - """Test sending an event that triggers calculator tool usage and polling for the response.""" |
47 | | - # Create a task for this conversation |
48 | | - task_response = await client.agents.create_task(agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex)) |
49 | | - task = task_response.result |
50 | | - assert task is not None |
| 46 | + async def test_send_event_and_poll_simple_query(self, test_agent: AsyncAgentTest): |
| 47 | + """Test basic agent functionality.""" |
| 48 | + # Wait for state initialization |
| 49 | + await asyncio.sleep(1) |
| 50 | + |
| 51 | + # Send a simple message that shouldn't require tool use |
| 52 | + response = await test_agent.send_event("Hello! Please introduce yourself briefly.", timeout_seconds=30.0) |
| 53 | + assert_valid_agent_response(response) |
51 | 54 |
|
| 55 | + @pytest.mark.asyncio |
| 56 | + async def test_send_event_and_poll_with_calculator(self, test_agent: AsyncAgentTest): |
| 57 | + """Test sending an event that triggers calculator tool usage and polling for the response.""" |
52 | 58 | # Wait for workflow to initialize |
53 | 59 | await asyncio.sleep(1) |
54 | 60 |
|
55 | 61 | # Send a message that could trigger the calculator tool (though with reasoning, it may not need it) |
56 | 62 | user_message = "What is 15 multiplied by 37?" |
57 | | - has_final_agent_response = False |
58 | | - |
59 | | - async for message in send_event_and_poll_yielding( |
60 | | - client=client, |
61 | | - agent_id=agent_id, |
62 | | - task_id=task.id, |
63 | | - user_message=user_message, |
64 | | - timeout=60, # Longer timeout for tool use |
65 | | - sleep_interval=1.0, |
66 | | - ): |
67 | | - assert isinstance(message, TaskMessage) |
68 | | - if message.content and message.content.type == "text" and message.content.author == "agent": |
69 | | - # Check that the answer contains 555 (15 * 37) |
70 | | - if "555" in message.content.content: |
71 | | - has_final_agent_response = True |
72 | | - break |
73 | | - |
74 | | - assert has_final_agent_response, "Did not receive final agent text response with correct answer" |
| 63 | + response = await test_agent.send_event(user_message, timeout_seconds=60.0) |
| 64 | + assert_valid_agent_response(response) |
| 65 | + assert "555" in response.content, "Expected answer '555' not found in agent response" |
75 | 66 |
|
76 | 67 | @pytest.mark.asyncio |
77 | | - async def test_multi_turn_conversation(self, client: AsyncAgentex, agent_id: str): |
| 68 | + async def test_multi_turn_conversation_with_state(self, test_agent: AsyncAgentTest): |
78 | 69 | """Test multiple turns of conversation with state preservation.""" |
79 | | - # Create a task for this conversation |
80 | | - task_response = await client.agents.create_task(agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex)) |
81 | | - task = task_response.result |
82 | | - assert task is not None |
83 | | - |
84 | 70 | # Wait for workflow to initialize |
85 | 71 | await asyncio.sleep(1) |
86 | 72 |
|
87 | | - # First turn |
88 | | - user_message_1 = "My favorite color is blue." |
89 | | - async for message in send_event_and_poll_yielding( |
90 | | - client=client, |
91 | | - agent_id=agent_id, |
92 | | - task_id=task.id, |
93 | | - user_message=user_message_1, |
94 | | - timeout=20, |
95 | | - sleep_interval=1.0, |
96 | | - ): |
97 | | - assert isinstance(message, TaskMessage) |
98 | | - if ( |
99 | | - message.content |
100 | | - and message.content.type == "text" |
101 | | - and message.content.author == "agent" |
102 | | - and message.content.content |
103 | | - ): |
104 | | - break |
105 | | - |
106 | | - # Wait a bit for state to update |
107 | | - await asyncio.sleep(2) |
108 | | - |
109 | | - # Second turn - reference previous context |
110 | | - found_response = False |
111 | | - user_message_2 = "What did I just tell you my favorite color was?" |
112 | | - async for message in send_event_and_poll_yielding( |
113 | | - client=client, |
114 | | - agent_id=agent_id, |
115 | | - task_id=task.id, |
116 | | - user_message=user_message_2, |
117 | | - timeout=30, |
118 | | - sleep_interval=1.0, |
119 | | - ): |
120 | | - if ( |
121 | | - message.content |
122 | | - and message.content.type == "text" |
123 | | - and message.content.author == "agent" |
124 | | - and message.content.content |
125 | | - ): |
126 | | - response_text = message.content.content.lower() |
127 | | - assert "blue" in response_text, f"Expected 'blue' in response but got: {response_text}" |
128 | | - found_response = True |
129 | | - break |
130 | | - |
131 | | - assert found_response, "Did not receive final agent text response with context recall" |
| 73 | + response = await test_agent.send_event("My favorite color is blue", timeout_seconds=30.0) |
| 74 | + assert_valid_agent_response(response) |
| 75 | + |
| 76 | + second_response = await test_agent.send_event( |
| 77 | + "What did I just tell you my favorite color was?", timeout_seconds=30.0 |
| 78 | + ) |
| 79 | + assert_valid_agent_response(second_response) |
| 80 | + assert "blue" in second_response.content.lower() |
132 | 81 |
|
133 | 82 |
|
134 | 83 | class TestStreamingEvents: |
135 | 84 | """Test streaming event sending with OpenAI Agents SDK and tool usage.""" |
136 | 85 |
|
137 | 86 | @pytest.mark.asyncio |
138 | | - async def test_send_event_and_stream_with_reasoning(self, client: AsyncAgentex, agent_id: str): |
139 | | - """Test streaming a simple response without tool usage.""" |
140 | | - # Create a task for this conversation |
141 | | - task_response = await client.agents.create_task(agent_id, params=ParamsCreateTaskRequest(name=uuid.uuid1().hex)) |
142 | | - task = task_response.result |
143 | | - assert task is not None |
144 | | - |
| 87 | + async def test_send_event_and_stream_with_reasoning(self, test_agent: AsyncAgentTest): |
| 88 | + """Test streaming event responses.""" |
145 | 89 | # Wait for workflow to initialize |
146 | 90 | await asyncio.sleep(1) |
147 | 91 |
|
| 92 | + # Send message and stream response |
148 | 93 | user_message = "Tell me a very short joke about programming." |
149 | 94 |
|
150 | 95 | # Check for user message and agent response |
151 | 96 | user_message_found = False |
152 | 97 | agent_response_found = False |
153 | 98 |
|
154 | | - async def stream_messages() -> None: # noqa: ANN101 |
155 | | - nonlocal user_message_found, agent_response_found |
156 | | - async for event in stream_agent_response( |
157 | | - client=client, |
158 | | - task_id=task.id, |
159 | | - timeout=60, |
160 | | - ): |
161 | | - msg_type = event.get("type") |
162 | | - if msg_type == "full": |
163 | | - task_message_update = StreamTaskMessageFull.model_validate(event) |
164 | | - if task_message_update.parent_task_message and task_message_update.parent_task_message.id: |
165 | | - finished_message = await client.messages.retrieve(task_message_update.parent_task_message.id) |
166 | | - if ( |
167 | | - finished_message.content |
168 | | - and finished_message.content.type == "text" |
169 | | - and finished_message.content.author == "user" |
170 | | - ): |
171 | | - user_message_found = True |
172 | | - elif ( |
173 | | - finished_message.content |
174 | | - and finished_message.content.type == "text" |
175 | | - and finished_message.content.author == "agent" |
176 | | - ): |
177 | | - agent_response_found = True |
178 | | - elif finished_message.content and finished_message.content.type == "reasoning": |
179 | | - tool_response_found = True |
180 | | - elif msg_type == "done": |
181 | | - task_message_update = StreamTaskMessageDone.model_validate(event) |
182 | | - if task_message_update.parent_task_message and task_message_update.parent_task_message.id: |
183 | | - finished_message = await client.messages.retrieve(task_message_update.parent_task_message.id) |
184 | | - if finished_message.content and finished_message.content.type == "reasoning": |
185 | | - agent_response_found = True |
186 | | - continue |
187 | | - |
188 | | - stream_task = asyncio.create_task(stream_messages()) |
189 | | - |
190 | | - event_content = TextContentParam(type="text", author="user", content=user_message) |
191 | | - await client.agents.send_event(agent_id=agent_id, params={"task_id": task.id, "content": event_content}) |
192 | | - |
193 | | - # Wait for streaming to complete |
194 | | - await stream_task |
| 99 | + # Stream events |
| 100 | + async for event in stream_agent_response(test_agent.client, test_agent.task_id, timeout=60.0): |
| 101 | + event_type = event.get("type") |
| 102 | + |
| 103 | + if event_type == "connected": |
| 104 | + await test_agent.send_event(user_message, timeout_seconds=30.0) |
| 105 | + |
| 106 | + elif event_type == "full": |
| 107 | + print('full event', event) |
| 108 | + task_message_update = StreamTaskMessageFull.model_validate(event) |
| 109 | + if task_message_update.parent_task_message and task_message_update.parent_task_message.id: |
| 110 | + finished_message = await test_agent.client.messages.retrieve(task_message_update.parent_task_message.id) |
| 111 | + if ( |
| 112 | + finished_message.content |
| 113 | + and finished_message.content.type == "text" |
| 114 | + and finished_message.content.author == "user" |
| 115 | + ): |
| 116 | + user_message_found = True |
| 117 | + elif ( |
| 118 | + finished_message.content |
| 119 | + and finished_message.content.type == "text" |
| 120 | + and finished_message.content.author == "agent" |
| 121 | + ): |
| 122 | + agent_response_found = True |
| 123 | + elif event_type == "done": |
| 124 | + print('done event', event) |
| 125 | + task_message_update = StreamTaskMessageDone.model_validate(event) |
| 126 | + if task_message_update.parent_task_message and task_message_update.parent_task_message.id: |
| 127 | + finished_message = await test_agent.client.messages.retrieve(task_message_update.parent_task_message.id) |
| 128 | + if finished_message.content and finished_message.content.type == "text" and finished_message.content.author == "agent": |
| 129 | + agent_response_found = True |
| 130 | + continue |
195 | 131 |
|
196 | 132 | assert user_message_found, "User message not found in stream" |
197 | 133 | assert agent_response_found, "Agent response not found in stream" |
|
0 commit comments