Skip to content

Commit bef115c

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Add async stream support to agent engines
PiperOrigin-RevId: 756896866
1 parent 03a7861 commit bef115c

File tree

5 files changed

+634
-19
lines changed

5 files changed

+634
-19
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

+90
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from google.cloud.aiplatform import initializer
2323
from vertexai.preview import reasoning_engines
2424
from vertexai.agent_engines import _utils
25+
26+
import asyncio
2527
import pytest
2628

2729

@@ -111,6 +113,33 @@ def run(self, *args, **kwargs):
111113
}
112114
)
113115

116+
async def run_async(self, *args, **kwargs):
117+
from google.adk.events import event
118+
119+
yield event.Event(
120+
**{
121+
"author": "currency_exchange_agent",
122+
"content": {
123+
"parts": [
124+
{
125+
"function_call": {
126+
"args": {
127+
"currency_date": "2025-04-03",
128+
"currency_from": "USD",
129+
"currency_to": "SEK",
130+
},
131+
"id": "af-c5a57692-9177-4091-a3df-098f834ee849",
132+
"name": "get_exchange_rate",
133+
}
134+
}
135+
],
136+
"role": "model",
137+
},
138+
"id": "9aaItGK9",
139+
"invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7",
140+
}
141+
)
142+
114143

115144
@pytest.mark.usefixtures("google_auth_mock")
116145
class TestAdkApp:
@@ -195,6 +224,45 @@ def test_stream_query_with_content(self):
195224
)
196225
assert len(events) == 1
197226

227+
@pytest.mark.asyncio
228+
async def test_async_stream_query(self):
229+
app = reasoning_engines.AdkApp(
230+
agent=Agent(name="test_agent", model=_TEST_MODEL)
231+
)
232+
assert app._tmpl_attrs.get("runner") is None
233+
app.set_up()
234+
app._tmpl_attrs["runner"] = _MockRunner()
235+
events = []
236+
async for event in app.async_stream_query(
237+
user_id="test_user_id",
238+
message="test message",
239+
):
240+
events.append(event)
241+
assert len(events) == 1
242+
243+
@pytest.mark.asyncio
244+
async def test_async_stream_query_with_content(self):
245+
app = reasoning_engines.AdkApp(
246+
agent=Agent(name="test_agent", model=_TEST_MODEL)
247+
)
248+
assert app._tmpl_attrs.get("runner") is None
249+
app.set_up()
250+
app._tmpl_attrs["runner"] = _MockRunner()
251+
events = []
252+
async for event in app.async_stream_query(
253+
user_id="test_user_id",
254+
message=types.Content(
255+
role="user",
256+
parts=[
257+
types.Part(
258+
text="test message with content",
259+
)
260+
],
261+
).model_dump(),
262+
):
263+
events.append(event)
264+
assert len(events) == 1
265+
198266
def test_streaming_agent_run_with_events(self):
199267
app = reasoning_engines.AdkApp(
200268
agent=Agent(name="test_agent", model=_TEST_MODEL)
@@ -322,3 +390,25 @@ def test_raise_get_session_not_found_error(self):
322390
user_id="non_existent_user",
323391
session_id="test_session_id",
324392
)
393+
394+
def test_stream_query_invalid_message_type(self):
395+
app = reasoning_engines.AdkApp(
396+
agent=Agent(name="test_agent", model=_TEST_MODEL)
397+
)
398+
with pytest.raises(
399+
TypeError,
400+
match="message must be a string or a dictionary representing a Content object.",
401+
):
402+
list(app.stream_query(user_id="test_user_id", message=123))
403+
404+
@pytest.mark.asyncio
405+
async def test_async_stream_query_invalid_message_type(self):
406+
app = reasoning_engines.AdkApp(
407+
agent=Agent(name="test_agent", model=_TEST_MODEL)
408+
)
409+
with pytest.raises(
410+
TypeError,
411+
match="message must be a string or a dictionary representing a Content object.",
412+
):
413+
async for _ in app.async_stream_query(user_id="test_user_id", message=123):
414+
pass

0 commit comments

Comments
 (0)