Skip to content

Commit 74e3f25

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Add live/bidi agent deployment support for Agent Engine
PiperOrigin-RevId: 803656127
1 parent d5a14ba commit 74e3f25

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,20 @@ async def custom_async_stream_method(
149149
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
150150
yield chunk
151151

152+
async def bidi_stream_query(self, input_queue: asyncio.Queue) -> AsyncIterable[Any]:
153+
"""Runs the bidi stream engine."""
154+
while True:
155+
chunk = await input_queue.get()
156+
yield chunk
157+
158+
async def custom_bidi_stream_method(
159+
self, input_queue: asyncio.Queue
160+
) -> AsyncIterable[Any]:
161+
"""Runs the async bidi stream engine."""
162+
while True:
163+
chunk = await input_queue.get()
164+
yield chunk
165+
152166
def clone(self):
153167
return self
154168

@@ -170,6 +184,10 @@ def register_operations(self) -> Dict[str, List[str]]:
170184
_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME,
171185
_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME,
172186
],
187+
_TEST_BIDI_STREAM_API_MODE: [
188+
_TEST_DEFAULT_BIDI_STREAM_METHOD_NAME,
189+
_TEST_CUSTOM_BIDI_STREAM_METHOD_NAME,
190+
],
173191
}
174192

175193

@@ -323,21 +341,27 @@ def register_operations(self) -> Dict[str, List[str]]:
323341
_TEST_ASYNC_API_MODE = _agent_engines_utils._ASYNC_API_MODE
324342
_TEST_STREAM_API_MODE = _agent_engines_utils._STREAM_API_MODE
325343
_TEST_ASYNC_STREAM_API_MODE = _agent_engines_utils._ASYNC_STREAM_API_MODE
344+
_TEST_BIDI_STREAM_API_MODE = _agent_engines_utils._BIDI_STREAM_API_MODE
326345
_TEST_DEFAULT_METHOD_NAME = _agent_engines_utils._DEFAULT_METHOD_NAME
327346
_TEST_DEFAULT_ASYNC_METHOD_NAME = _agent_engines_utils._DEFAULT_ASYNC_METHOD_NAME
328347
_TEST_DEFAULT_STREAM_METHOD_NAME = _agent_engines_utils._DEFAULT_STREAM_METHOD_NAME
329348
_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME = (
330349
_agent_engines_utils._DEFAULT_ASYNC_STREAM_METHOD_NAME
331350
)
351+
_TEST_DEFAULT_BIDI_STREAM_METHOD_NAME = (
352+
_agent_engines_utils._DEFAULT_BIDI_STREAM_METHOD_NAME
353+
)
332354
_TEST_CAPITALIZE_ENGINE_METHOD_DOCSTRING = "Runs the engine."
333355
_TEST_STREAM_METHOD_DOCSTRING = "Runs the stream engine."
334356
_TEST_ASYNC_STREAM_METHOD_DOCSTRING = "Runs the async stream engine."
357+
_TEST_BIDI_STREAM_METHOD_DOCSTRING = "Runs the bidi stream engine."
335358
_TEST_MODE_KEY_IN_SCHEMA = _agent_engines_utils._MODE_KEY_IN_SCHEMA
336359
_TEST_METHOD_NAME_KEY_IN_SCHEMA = _agent_engines_utils._METHOD_NAME_KEY_IN_SCHEMA
337360
_TEST_CUSTOM_METHOD_NAME = "custom_method"
338361
_TEST_CUSTOM_ASYNC_METHOD_NAME = "custom_async_method"
339362
_TEST_CUSTOM_STREAM_METHOD_NAME = "custom_stream_method"
340363
_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME = "custom_async_stream_method"
364+
_TEST_CUSTOM_BIDI_STREAM_METHOD_NAME = "custom_bidi_stream_method"
341365
_TEST_CUSTOM_METHOD_DEFAULT_DOCSTRING = """
342366
Runs the Agent Engine to serve the user request.
343367
@@ -579,6 +603,22 @@ def register_operations(self) -> Dict[str, List[str]]:
579603
_TEST_AGENT_ENGINE_CUSTOM_ASYNC_STREAM_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = (
580604
_TEST_ASYNC_STREAM_API_MODE
581605
)
606+
_TEST_AGENT_ENGINE_BIDI_STREAM_QUERY_SCHEMA = _agent_engines_utils._generate_schema(
607+
OperationRegistrableEngine().bidi_stream_query,
608+
schema_name=_TEST_DEFAULT_BIDI_STREAM_METHOD_NAME,
609+
)
610+
_TEST_AGENT_ENGINE_BIDI_STREAM_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = (
611+
_TEST_BIDI_STREAM_API_MODE
612+
)
613+
_TEST_AGENT_ENGINE_CUSTOM_BIDI_STREAM_QUERY_SCHEMA = (
614+
_agent_engines_utils._generate_schema(
615+
OperationRegistrableEngine().custom_bidi_stream_method,
616+
schema_name=_TEST_CUSTOM_BIDI_STREAM_METHOD_NAME,
617+
)
618+
)
619+
_TEST_AGENT_ENGINE_CUSTOM_BIDI_STREAM_QUERY_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = (
620+
_TEST_BIDI_STREAM_API_MODE
621+
)
582622
_TEST_OPERATION_REGISTRABLE_SCHEMAS = [
583623
_TEST_AGENT_ENGINE_QUERY_SCHEMA,
584624
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA,
@@ -588,6 +628,8 @@ def register_operations(self) -> Dict[str, List[str]]:
588628
_TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA,
589629
_TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA,
590630
_TEST_AGENT_ENGINE_CUSTOM_ASYNC_STREAM_QUERY_SCHEMA,
631+
_TEST_AGENT_ENGINE_BIDI_STREAM_QUERY_SCHEMA,
632+
_TEST_AGENT_ENGINE_CUSTOM_BIDI_STREAM_QUERY_SCHEMA,
591633
]
592634
_TEST_OPERATION_NOT_REGISTERED_SCHEMAS = [
593635
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA,
@@ -1873,6 +1915,20 @@ async def consume():
18731915
),
18741916
_TEST_ASYNC_STREAM_API_MODE,
18751917
),
1918+
(
1919+
_agent_engines_utils._generate_schema(
1920+
OperationRegistrableEngine().bidi_stream_query,
1921+
schema_name=_TEST_DEFAULT_BIDI_STREAM_METHOD_NAME,
1922+
),
1923+
_TEST_BIDI_STREAM_API_MODE,
1924+
),
1925+
(
1926+
_agent_engines_utils._generate_schema(
1927+
OperationRegistrableEngine().custom_bidi_stream_method,
1928+
schema_name=_TEST_CUSTOM_BIDI_STREAM_METHOD_NAME,
1929+
),
1930+
_TEST_BIDI_STREAM_API_MODE,
1931+
),
18761932
],
18771933
),
18781934
(

vertexai/_genai/_agent_engines_utils.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Utility functions for agent engines."""
1616

1717
import abc
18+
import asyncio
1819
from importlib import metadata as importlib_metadata
1920
import inspect
2021
import io
@@ -108,6 +109,7 @@
108109
_AGENT_FRAMEWORK_ATTR = "agent_framework"
109110
_ASYNC_API_MODE = "async"
110111
_ASYNC_STREAM_API_MODE = "async_stream"
112+
_BIDI_STREAM_API_MODE = "bidi_stream"
111113
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
112114
_BLOB_FILENAME = "agent_engine.pkl"
113115
_DEFAULT_AGENT_FRAMEWORK = "custom"
@@ -132,6 +134,7 @@
132134
_DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]"
133135
_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"])
134136
_DEFAULT_STREAM_METHOD_NAME = "stream_query"
137+
_DEFAULT_BIDI_STREAM_METHOD_NAME = "bidi_stream_query"
135138
_EXTRA_PACKAGES_FILE = "dependencies.tar.gz"
136139
_FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE = (
137140
"Failed to register API methods. Please follow the guide to "
@@ -202,6 +205,15 @@ def stream_query(self, **kwargs) -> Iterator[Any]: # type: ignore[no-untyped-de
202205
"""Stream responses to serve the user query."""
203206

204207

208+
@typing.runtime_checkable
209+
class BidiStreamQueryable(Protocol):
210+
"""Protocol for Agent Engines that can stream requests and responses."""
211+
212+
@abc.abstractmethod
213+
async def bidi_stream_query(self, input_queue: asyncio.Queue) -> AsyncIterator[Any]:
214+
"""Stream requests and responses to serve the user queries."""
215+
216+
205217
@typing.runtime_checkable
206218
class Cloneable(Protocol):
207219
"""Protocol for Agent Engines that can be cloned."""
@@ -234,6 +246,7 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
234246
OperationRegistrable,
235247
Queryable,
236248
StreamQueryable,
249+
BidiStreamQueryable,
237250
]
238251

239252

@@ -557,6 +570,9 @@ def _generate_schema(
557570
inspect.Parameter.KEYWORD_ONLY,
558571
inspect.Parameter.POSITIONAL_ONLY,
559572
)
573+
# For a bidi endpoint, it requires an asyncio.Queue as the input, but
574+
# it is not JSON serializable. We hence exclude it from the schema.
575+
and param.annotation != asyncio.Queue
560576
}
561577
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
562578
# Postprocessing
@@ -656,6 +672,8 @@ def _get_registered_operations(
656672
operations[_STREAM_API_MODE] = [_DEFAULT_STREAM_METHOD_NAME]
657673
if isinstance(agent, AsyncStreamQueryable):
658674
operations[_ASYNC_STREAM_API_MODE] = [_DEFAULT_ASYNC_STREAM_METHOD_NAME]
675+
if isinstance(agent, BidiStreamQueryable):
676+
operations[_BIDI_STREAM_API_MODE] = [_DEFAULT_BIDI_STREAM_METHOD_NAME]
659677
return operations
660678

661679

@@ -839,6 +857,10 @@ def _register_api_methods_or_raise(
839857
f" contain an `{_MODE_KEY_IN_SCHEMA}` field."
840858
)
841859
api_mode = operation_schema.get(_MODE_KEY_IN_SCHEMA)
860+
# For bidi stream api mode, we don't need to wrap the operation.
861+
if api_mode == _BIDI_STREAM_API_MODE:
862+
continue
863+
842864
if _METHOD_NAME_KEY_IN_SCHEMA not in operation_schema:
843865
raise ValueError(
844866
f"Operation schema {operation_schema} does not"
@@ -1212,6 +1234,7 @@ def _validate_agent_or_raise(
12121234
* a callable method named `query`
12131235
* a callable method named `stream_query`
12141236
* a callable method named `async_stream_query`
1237+
* a callable method named `bidi_stream_query`
12151238
* a callable method named `register_operations`
12161239
12171240
Args:
@@ -1246,6 +1269,9 @@ def _validate_agent_or_raise(
12461269
is_async_stream_queryable = isinstance(agent, AsyncStreamQueryable) and callable(
12471270
agent.async_stream_query
12481271
)
1272+
is_bidi_stream_queryable = isinstance(agent, BidiStreamQueryable) and callable(
1273+
agent.bidi_stream_query
1274+
)
12491275
is_operation_registrable = isinstance(agent, OperationRegistrable) and callable(
12501276
agent.register_operations
12511277
)
@@ -1255,12 +1281,13 @@ def _validate_agent_or_raise(
12551281
or is_async_queryable
12561282
or is_stream_queryable
12571283
or is_operation_registrable
1284+
or is_bidi_stream_queryable
12581285
or is_async_stream_queryable
12591286
):
12601287
raise TypeError(
12611288
"agent_engine has none of the following callable methods: "
1262-
"`query`, `async_query`, `stream_query`, `async_stream_query` or "
1263-
"`register_operations`."
1289+
"`query`, `async_query`, `stream_query`, `async_stream_query`, "
1290+
"`bidi_stream_query`, or `register_operations`."
12641291
)
12651292

12661293
if is_queryable:
@@ -1299,6 +1326,15 @@ def _validate_agent_or_raise(
12991326
" missing `self` argument in the agent.async_stream_query method."
13001327
) from err
13011328

1329+
if is_bidi_stream_queryable:
1330+
try:
1331+
inspect.signature(getattr(agent, "bidi_stream_query"))
1332+
except ValueError as err:
1333+
raise ValueError(
1334+
"Invalid bidi_stream_query signature. This might be due to a "
1335+
" missing `self` argument in the agent.bidi_stream_query method."
1336+
) from err
1337+
13021338
if is_operation_registrable:
13031339
try:
13041340
inspect.signature(getattr(agent, "register_operations"))

vertexai/_genai/agent_engines.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,18 @@ def _create_config(
12201220
_agent_engines_utils._to_dict(class_method)
12211221
for class_method in class_methods
12221222
]
1223+
# Set the agent_server_mode to EXPERIMENTAL if the agent has a
1224+
# bidi_stream method.
1225+
for class_method in class_methods:
1226+
if class_method["api_mode"] == "bidi_stream":
1227+
if not agent_engine_spec.get("deployment_spec"):
1228+
agent_engine_spec["deployment_spec"] = (
1229+
types.ReasoningEngineSpecDeploymentSpecDict()
1230+
)
1231+
agent_engine_spec["deployment_spec"][
1232+
"agent_server_mode"
1233+
] = types.AgentServerMode.EXPERIMENTAL
1234+
break
12231235
update_masks.append("spec.class_methods")
12241236
agent_engine_spec["agent_framework"] = (
12251237
_agent_engines_utils._get_agent_framework(agent=agent)

0 commit comments

Comments
 (0)