Skip to content

Commit 1b7d29c

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add async support for reasoning engine.
PiperOrigin-RevId: 763862748
1 parent 9427b15 commit 1b7d29c

File tree

5 files changed

+245
-83
lines changed

5 files changed

+245
-83
lines changed

google/cloud/aiplatform/compat/services/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@
179179
from google.cloud.aiplatform_v1.services.reasoning_engine_execution_service import (
180180
client as reasoning_engine_execution_service_client_v1,
181181
)
182+
from google.cloud.aiplatform_v1.services.reasoning_engine_execution_service import (
183+
async_client as reasoning_engine_execution_async_client_v1,
184+
)
182185
from google.cloud.aiplatform_v1.services.schedule_service import (
183186
client as schedule_service_client_v1,
184187
)

google/cloud/aiplatform/utils/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
prediction_service_async_client_v1,
9191
reasoning_engine_service_client_v1,
9292
reasoning_engine_execution_service_client_v1,
93+
reasoning_engine_execution_async_client_v1,
9394
schedule_service_client_v1,
9495
tensorboard_service_client_v1,
9596
vizier_service_client_v1,
@@ -1007,6 +1008,17 @@ class AgentEngineExecutionClientWithOverride(ClientWithOverride):
10071008
)
10081009

10091010

1011+
class AgentEngineExecutionAsyncClientWithOverride(ClientWithOverride):
1012+
_is_temporary = True
1013+
_default_version = compat.V1
1014+
_version_map = (
1015+
(
1016+
compat.V1,
1017+
reasoning_engine_execution_async_client_v1.ReasoningEngineExecutionServiceAsyncClient,
1018+
),
1019+
)
1020+
1021+
10101022
class VertexRagDataClientWithOverride(ClientWithOverride):
10111023
_is_temporary = True
10121024
_default_version = compat.DEFAULT_VERSION

tests/unit/vertex_langchain/test_agent_engines.py

Lines changed: 143 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ def clone(self):
6363
return self
6464

6565

66+
class AsyncQueryEngine:
67+
"""A sample Agent Engine that implements `async_query`."""
68+
69+
def set_up(self):
70+
pass
71+
72+
async def async_query(self, unused_arbitrary_string_name: str):
73+
"""Runs the query asynchronously."""
74+
return unused_arbitrary_string_name.upper()
75+
76+
def clone(self):
77+
return self
78+
79+
6680
class AsyncStreamQueryEngine:
6781
"""A sample Agent Engine that implements `async_stream_query`."""
6882

@@ -104,10 +118,18 @@ def query(self, unused_arbitrary_string_name: str) -> str:
104118
"""Runs the engine."""
105119
return unused_arbitrary_string_name.upper()
106120

121+
async def async_query(self, unused_arbitrary_string_name: str) -> str:
122+
"""Runs the query asynchronously."""
123+
return unused_arbitrary_string_name.upper()
124+
107125
# Add a custom method to test the custom method registration.
108126
def custom_method(self, x: str) -> str:
109127
return x.upper()
110128

129+
# Add a custom async method to test the custom async method registration.
130+
async def custom_async_method(self, x: str):
131+
return x.upper()
132+
111133
def stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]:
112134
"""Runs the stream engine."""
113135
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
@@ -146,6 +168,10 @@ def register_operations(self) -> Dict[str, List[str]]:
146168
_TEST_DEFAULT_METHOD_NAME,
147169
_TEST_CUSTOM_METHOD_NAME,
148170
],
171+
_TEST_ASYNC_API_MODE: [
172+
_TEST_DEFAULT_ASYNC_METHOD_NAME,
173+
_TEST_CUSTOM_ASYNC_METHOD_NAME,
174+
],
149175
_TEST_STREAM_API_MODE: [
150176
_TEST_DEFAULT_STREAM_METHOD_NAME,
151177
_TEST_CUSTOM_STREAM_METHOD_NAME,
@@ -164,14 +190,22 @@ def query(self, unused_arbitrary_string_name: str) -> str:
164190
"""Runs the engine."""
165191
return unused_arbitrary_string_name.upper()
166192

193+
async def async_query(self, unused_arbitrary_string_name: str) -> str:
194+
"""Runs the query asynchronously."""
195+
return unused_arbitrary_string_name.upper()
196+
167197
# Add a custom method to test the custom method registration
168198
def custom_method(self, x: str) -> str:
169199
return x.upper()
170200

171-
# Add a custom method that is not registered.ration
201+
# Add a custom method that is not registered.
172202
def custom_method_2(self, x: str) -> str:
173203
return x.upper()
174204

205+
# Add a custom async method to test the custom async method registration.
206+
async def custom_async_method(self, x: str):
207+
return x.upper()
208+
175209
def stream_query(self, unused_arbitrary_string_name: str) -> Iterable[Any]:
176210
"""Runs the stream engine."""
177211
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE:
@@ -204,6 +238,10 @@ def register_operations(self) -> Dict[str, List[str]]:
204238
_TEST_DEFAULT_METHOD_NAME,
205239
_TEST_CUSTOM_METHOD_NAME,
206240
],
241+
_TEST_ASYNC_API_MODE: [
242+
_TEST_DEFAULT_ASYNC_METHOD_NAME,
243+
_TEST_CUSTOM_ASYNC_METHOD_NAME,
244+
],
207245
_TEST_STREAM_API_MODE: [
208246
_TEST_DEFAULT_STREAM_METHOD_NAME,
209247
_TEST_CUSTOM_STREAM_METHOD_NAME,
@@ -291,9 +329,11 @@ def register_operations(self) -> Dict[str, List[str]]:
291329
_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE
292330
_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE
293331
_TEST_STANDARD_API_MODE = _agent_engines._STANDARD_API_MODE
332+
_TEST_ASYNC_API_MODE = _agent_engines._ASYNC_API_MODE
294333
_TEST_STREAM_API_MODE = _agent_engines._STREAM_API_MODE
295334
_TEST_ASYNC_STREAM_API_MODE = _agent_engines._ASYNC_STREAM_API_MODE
296335
_TEST_DEFAULT_METHOD_NAME = _agent_engines._DEFAULT_METHOD_NAME
336+
_TEST_DEFAULT_ASYNC_METHOD_NAME = _agent_engines._DEFAULT_ASYNC_METHOD_NAME
297337
_TEST_DEFAULT_STREAM_METHOD_NAME = _agent_engines._DEFAULT_STREAM_METHOD_NAME
298338
_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME = (
299339
_agent_engines._DEFAULT_ASYNC_STREAM_METHOD_NAME
@@ -304,6 +344,7 @@ def register_operations(self) -> Dict[str, List[str]]:
304344
_TEST_MODE_KEY_IN_SCHEMA = _agent_engines._MODE_KEY_IN_SCHEMA
305345
_TEST_METHOD_NAME_KEY_IN_SCHEMA = _agent_engines._METHOD_NAME_KEY_IN_SCHEMA
306346
_TEST_CUSTOM_METHOD_NAME = "custom_method"
347+
_TEST_CUSTOM_ASYNC_METHOD_NAME = "custom_async_method"
307348
_TEST_CUSTOM_STREAM_METHOD_NAME = "custom_stream_method"
308349
_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME = "custom_async_stream_method"
309350
_TEST_CUSTOM_METHOD_DEFAULT_DOCSTRING = """
@@ -320,6 +361,20 @@ def register_operations(self) -> Dict[str, List[str]]:
320361
Returns:
321362
dict[str, Any]: The response from serving the user request.
322363
"""
364+
_TEST_CUSTOM_ASYNC_METHOD_DEFAULT_DOCSTRING = """
365+
Runs the Agent Engine to serve the user request.
366+
367+
This will be based on the `.custom_async_method(...)` of the python object that
368+
was passed in when creating the Agent Engine. The method will invoke the
369+
`async_query` API client of the python object.
370+
371+
Args:
372+
**kwargs:
373+
Optional. The arguments of the `.custom_async_method(...)` method.
374+
375+
Returns:
376+
Coroutine[Any]: The response from serving the user request.
377+
"""
323378
_TEST_CUSTOM_STREAM_METHOD_DEFAULT_DOCSTRING = """
324379
Runs the Agent Engine to serve the user request.
325380
@@ -429,6 +484,13 @@ def register_operations(self) -> Dict[str, List[str]]:
429484
]
430485
_TEST_AGENT_ENGINE_OPERATION_SCHEMAS = []
431486
_TEST_AGENT_ENGINE_EXTRA_PACKAGE = "fake.py"
487+
_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA = _utils.to_proto(
488+
_utils.generate_schema(
489+
AsyncQueryEngine().async_query,
490+
schema_name=_TEST_DEFAULT_ASYNC_METHOD_NAME,
491+
)
492+
)
493+
_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA[_TEST_MODE_KEY_IN_SCHEMA] = _TEST_ASYNC_API_MODE
432494
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA = _utils.to_proto(
433495
_utils.generate_schema(
434496
OperationRegistrableEngine().custom_method,
@@ -438,6 +500,15 @@ def register_operations(self) -> Dict[str, List[str]]:
438500
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA[
439501
_TEST_MODE_KEY_IN_SCHEMA
440502
] = _TEST_STANDARD_API_MODE
503+
_TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA = _utils.to_proto(
504+
_utils.generate_schema(
505+
OperationRegistrableEngine().custom_async_method,
506+
schema_name=_TEST_CUSTOM_ASYNC_METHOD_NAME,
507+
)
508+
)
509+
_TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA[
510+
_TEST_MODE_KEY_IN_SCHEMA
511+
] = _TEST_ASYNC_API_MODE
441512
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA = _utils.to_proto(
442513
_utils.generate_schema(
443514
StreamQueryEngine().stream_query,
@@ -475,6 +546,8 @@ def register_operations(self) -> Dict[str, List[str]]:
475546
_TEST_OPERATION_REGISTRABLE_SCHEMAS = [
476547
_TEST_AGENT_ENGINE_QUERY_SCHEMA,
477548
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA,
549+
_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA,
550+
_TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA,
478551
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA,
479552
_TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA,
480553
_TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA,
@@ -499,6 +572,7 @@ def register_operations(self) -> Dict[str, List[str]]:
499572
_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA[
500573
_TEST_MODE_KEY_IN_SCHEMA
501574
] = _TEST_STANDARD_API_MODE
575+
_TEST_ASYNC_QUERY_SCHEMAS = [_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA]
502576
_TEST_STREAM_QUERY_SCHEMAS = [
503577
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA,
504578
]
@@ -758,6 +832,17 @@ def query() -> str:
758832
return "RESPONSE"
759833

760834

835+
class InvalidCapitalizeEngineWithoutAsyncQuerySelf:
836+
"""A sample Agent Engine with an invalid async_query method."""
837+
838+
def set_up(self):
839+
pass
840+
841+
async def async_query() -> str:
842+
"""Runs the engine."""
843+
return "RESPONSE"
844+
845+
761846
class InvalidCapitalizeEngineWithoutStreamQuerySelf:
762847
"""A sample Agent Engine with an invalid query_stream_query method."""
763848

@@ -1161,6 +1246,23 @@ def test_get_agent_framework(
11611246
),
11621247
),
11631248
),
1249+
(
1250+
"Update the async query engine",
1251+
{"agent_engine": AsyncQueryEngine()},
1252+
types.reasoning_engine_service.UpdateReasoningEngineRequest(
1253+
reasoning_engine=_generate_agent_engine_with_class_methods_and_agent_framework(
1254+
_TEST_ASYNC_QUERY_SCHEMAS,
1255+
_agent_engines._DEFAULT_AGENT_FRAMEWORK,
1256+
),
1257+
update_mask=field_mask_pb2.FieldMask(
1258+
paths=[
1259+
"spec.package_spec.pickle_object_gcs_uri",
1260+
"spec.class_methods",
1261+
"spec.agent_framework",
1262+
]
1263+
),
1264+
),
1265+
),
11641266
(
11651267
"Update the stream query engine",
11661268
{"agent_engine": StreamQueryEngine()},
@@ -1534,6 +1636,20 @@ def test_query_agent_engine(
15341636
),
15351637
_TEST_STANDARD_API_MODE,
15361638
),
1639+
(
1640+
_utils.generate_schema(
1641+
OperationRegistrableEngine().async_query,
1642+
schema_name=_TEST_DEFAULT_ASYNC_METHOD_NAME,
1643+
),
1644+
_TEST_ASYNC_API_MODE,
1645+
),
1646+
(
1647+
_utils.generate_schema(
1648+
OperationRegistrableEngine().custom_async_method,
1649+
schema_name=_TEST_CUSTOM_ASYNC_METHOD_NAME,
1650+
),
1651+
_TEST_ASYNC_API_MODE,
1652+
),
15371653
(
15381654
_utils.generate_schema(
15391655
OperationRegistrableEngine().stream_query,
@@ -2320,8 +2436,8 @@ def test_create_agent_engine_no_query_method(
23202436
TypeError,
23212437
match=(
23222438
"agent_engine has none of the following callable methods: "
2323-
"`query`, `stream_query`, `async_stream_query` or "
2324-
"`register_operations`."
2439+
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2440+
"or `register_operations`."
23252441
),
23262442
):
23272443
agent_engines.create(
@@ -2344,8 +2460,8 @@ def test_create_agent_engine_noncallable_query_attribute(
23442460
TypeError,
23452461
match=(
23462462
"agent_engine has none of the following callable methods: "
2347-
"`query`, `stream_query`, `async_stream_query` or "
2348-
"`register_operations`."
2463+
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2464+
"or `register_operations`."
23492465
),
23502466
):
23512467
agent_engines.create(
@@ -2406,6 +2522,23 @@ def test_create_agent_engine_with_invalid_query_method(
24062522
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
24072523
)
24082524

2525+
def test_create_agent_engine_with_invalid_async_query_method(
2526+
self,
2527+
create_agent_engine_mock,
2528+
cloud_storage_create_bucket_mock,
2529+
tarfile_open_mock,
2530+
cloudpickle_dump_mock,
2531+
cloudpickle_load_mock,
2532+
importlib_metadata_version_mock,
2533+
get_agent_engine_mock,
2534+
):
2535+
with pytest.raises(ValueError, match="Invalid async_query signature"):
2536+
agent_engines.create(
2537+
InvalidCapitalizeEngineWithoutAsyncQuerySelf(),
2538+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2539+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
2540+
)
2541+
24092542
def test_create_agent_engine_with_invalid_stream_query_method(
24102543
self,
24112544
create_agent_engine_mock,
@@ -2574,8 +2707,8 @@ def test_update_agent_engine_no_query_method(
25742707
TypeError,
25752708
match=(
25762709
"agent_engine has none of the following callable methods: "
2577-
"`query`, `stream_query`, `async_stream_query` or "
2578-
"`register_operations`."
2710+
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2711+
"or `register_operations`."
25792712
),
25802713
):
25812714
test_agent_engine = _generate_agent_engine_to_update()
@@ -2597,8 +2730,8 @@ def test_update_agent_engine_noncallable_query_attribute(
25972730
TypeError,
25982731
match=(
25992732
"agent_engine has none of the following callable methods: "
2600-
"`query`, `stream_query`, `async_stream_query` or "
2601-
"`register_operations`."
2733+
"`query`, `async_query`, `stream_query`, `async_stream_query` "
2734+
"or `register_operations`."
26022735
),
26032736
):
26042737
test_agent_engine = _generate_agent_engine_to_update()
@@ -2737,7 +2870,7 @@ def test_update_class_methods_spec_with_registered_operation_not_found(self):
27372870
"register the API methods: "
27382871
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
27392872
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
2740-
"Supported modes are: ``, `stream` and `async_stream`.}"
2873+
"Supported modes are: ``, `async`, `stream` and `async_stream`.}"
27412874
),
27422875
),
27432876
],

vertexai/agent_engines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ModuleAgent,
3030
OperationRegistrable,
3131
Queryable,
32+
AsyncQueryable,
3233
StreamQueryable,
3334
AsyncStreamQueryable,
3435
)
@@ -313,6 +314,7 @@ def update(
313314
"Cloneable",
314315
"OperationRegistrable",
315316
"Queryable",
317+
"AsyncQueryable",
316318
"StreamQueryable",
317319
"AsyncStreamQueryable",
318320
# Methods

0 commit comments

Comments
 (0)