|
22 | 22 | from google.cloud.aiplatform import initializer
|
23 | 23 | from vertexai.preview import reasoning_engines
|
24 | 24 | from vertexai.agent_engines import _utils
|
| 25 | + |
| 26 | +import asyncio |
25 | 27 | import pytest
|
26 | 28 |
|
27 | 29 |
|
@@ -111,6 +113,33 @@ def run(self, *args, **kwargs):
|
111 | 113 | }
|
112 | 114 | )
|
113 | 115 |
|
| 116 | + async def run_async(self, *args, **kwargs): |
| 117 | + from google.adk.events import event |
| 118 | + |
| 119 | + yield event.Event( |
| 120 | + **{ |
| 121 | + "author": "currency_exchange_agent", |
| 122 | + "content": { |
| 123 | + "parts": [ |
| 124 | + { |
| 125 | + "function_call": { |
| 126 | + "args": { |
| 127 | + "currency_date": "2025-04-03", |
| 128 | + "currency_from": "USD", |
| 129 | + "currency_to": "SEK", |
| 130 | + }, |
| 131 | + "id": "af-c5a57692-9177-4091-a3df-098f834ee849", |
| 132 | + "name": "get_exchange_rate", |
| 133 | + } |
| 134 | + } |
| 135 | + ], |
| 136 | + "role": "model", |
| 137 | + }, |
| 138 | + "id": "9aaItGK9", |
| 139 | + "invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7", |
| 140 | + } |
| 141 | + ) |
| 142 | + |
114 | 143 |
|
115 | 144 | @pytest.mark.usefixtures("google_auth_mock")
|
116 | 145 | class TestAdkApp:
|
@@ -195,6 +224,45 @@ def test_stream_query_with_content(self):
|
195 | 224 | )
|
196 | 225 | assert len(events) == 1
|
197 | 226 |
|
| 227 | + @pytest.mark.asyncio |
| 228 | + async def test_async_stream_query(self): |
| 229 | + app = reasoning_engines.AdkApp( |
| 230 | + agent=Agent(name="test_agent", model=_TEST_MODEL) |
| 231 | + ) |
| 232 | + assert app._tmpl_attrs.get("runner") is None |
| 233 | + app.set_up() |
| 234 | + app._tmpl_attrs["runner"] = _MockRunner() |
| 235 | + events = [] |
| 236 | + async for event in app.async_stream_query( |
| 237 | + user_id="test_user_id", |
| 238 | + message="test message", |
| 239 | + ): |
| 240 | + events.append(event) |
| 241 | + assert len(events) == 1 |
| 242 | + |
| 243 | + @pytest.mark.asyncio |
| 244 | + async def test_async_stream_query_with_content(self): |
| 245 | + app = reasoning_engines.AdkApp( |
| 246 | + agent=Agent(name="test_agent", model=_TEST_MODEL) |
| 247 | + ) |
| 248 | + assert app._tmpl_attrs.get("runner") is None |
| 249 | + app.set_up() |
| 250 | + app._tmpl_attrs["runner"] = _MockRunner() |
| 251 | + events = [] |
| 252 | + async for event in app.async_stream_query( |
| 253 | + user_id="test_user_id", |
| 254 | + message=types.Content( |
| 255 | + role="user", |
| 256 | + parts=[ |
| 257 | + types.Part( |
| 258 | + text="test message with content", |
| 259 | + ) |
| 260 | + ], |
| 261 | + ).model_dump(), |
| 262 | + ): |
| 263 | + events.append(event) |
| 264 | + assert len(events) == 1 |
| 265 | + |
198 | 266 | def test_streaming_agent_run_with_events(self):
|
199 | 267 | app = reasoning_engines.AdkApp(
|
200 | 268 | agent=Agent(name="test_agent", model=_TEST_MODEL)
|
@@ -322,3 +390,25 @@ def test_raise_get_session_not_found_error(self):
|
322 | 390 | user_id="non_existent_user",
|
323 | 391 | session_id="test_session_id",
|
324 | 392 | )
|
| 393 | + |
| 394 | + def test_stream_query_invalid_message_type(self): |
| 395 | + app = reasoning_engines.AdkApp( |
| 396 | + agent=Agent(name="test_agent", model=_TEST_MODEL) |
| 397 | + ) |
| 398 | + with pytest.raises( |
| 399 | + TypeError, |
| 400 | + match="message must be a string or a dictionary representing a Content object.", |
| 401 | + ): |
| 402 | + list(app.stream_query(user_id="test_user_id", message=123)) |
| 403 | + |
| 404 | + @pytest.mark.asyncio |
| 405 | + async def test_async_stream_query_invalid_message_type(self): |
| 406 | + app = reasoning_engines.AdkApp( |
| 407 | + agent=Agent(name="test_agent", model=_TEST_MODEL) |
| 408 | + ) |
| 409 | + with pytest.raises( |
| 410 | + TypeError, |
| 411 | + match="message must be a string or a dictionary representing a Content object.", |
| 412 | + ): |
| 413 | + async for _ in app.async_stream_query(user_id="test_user_id", message=123): |
| 414 | + pass |
0 commit comments