@@ -63,6 +63,20 @@ def clone(self):
63
63
return self
64
64
65
65
66
+ class AsyncQueryEngine :
67
+ """A sample Agent Engine that implements `async_query`."""
68
+
69
+ def set_up (self ):
70
+ pass
71
+
72
+ async def async_query (self , unused_arbitrary_string_name : str ):
73
+ """Runs the query asynchronously."""
74
+ return unused_arbitrary_string_name .upper ()
75
+
76
+ def clone (self ):
77
+ return self
78
+
79
+
66
80
class AsyncStreamQueryEngine :
67
81
"""A sample Agent Engine that implements `async_stream_query`."""
68
82
@@ -104,10 +118,18 @@ def query(self, unused_arbitrary_string_name: str) -> str:
104
118
"""Runs the engine."""
105
119
return unused_arbitrary_string_name .upper ()
106
120
121
+ async def async_query (self , unused_arbitrary_string_name : str ) -> str :
122
+ """Runs the query asynchronously."""
123
+ return unused_arbitrary_string_name .upper ()
124
+
107
125
# Add a custom method to test the custom method registration.
108
126
def custom_method (self , x : str ) -> str :
109
127
return x .upper ()
110
128
129
+ # Add a custom async method to test the custom async method registration.
130
+ async def custom_async_method (self , x : str ):
131
+ return x .upper ()
132
+
111
133
def stream_query (self , unused_arbitrary_string_name : str ) -> Iterable [Any ]:
112
134
"""Runs the stream engine."""
113
135
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE :
@@ -146,6 +168,10 @@ def register_operations(self) -> Dict[str, List[str]]:
146
168
_TEST_DEFAULT_METHOD_NAME ,
147
169
_TEST_CUSTOM_METHOD_NAME ,
148
170
],
171
+ _TEST_ASYNC_API_MODE : [
172
+ _TEST_DEFAULT_ASYNC_METHOD_NAME ,
173
+ _TEST_CUSTOM_ASYNC_METHOD_NAME ,
174
+ ],
149
175
_TEST_STREAM_API_MODE : [
150
176
_TEST_DEFAULT_STREAM_METHOD_NAME ,
151
177
_TEST_CUSTOM_STREAM_METHOD_NAME ,
@@ -164,14 +190,22 @@ def query(self, unused_arbitrary_string_name: str) -> str:
164
190
"""Runs the engine."""
165
191
return unused_arbitrary_string_name .upper ()
166
192
193
+ async def async_query (self , unused_arbitrary_string_name : str ) -> str :
194
+ """Runs the query asynchronously."""
195
+ return unused_arbitrary_string_name .upper ()
196
+
167
197
# Add a custom method to test the custom method registration
168
198
def custom_method (self , x : str ) -> str :
169
199
return x .upper ()
170
200
171
- # Add a custom method that is not registered.ration
201
+ # Add a custom method that is not registered.
172
202
def custom_method_2 (self , x : str ) -> str :
173
203
return x .upper ()
174
204
205
+ # Add a custom async method to test the custom async method registration.
206
+ async def custom_async_method (self , x : str ):
207
+ return x .upper ()
208
+
175
209
def stream_query (self , unused_arbitrary_string_name : str ) -> Iterable [Any ]:
176
210
"""Runs the stream engine."""
177
211
for chunk in _TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE :
@@ -204,6 +238,10 @@ def register_operations(self) -> Dict[str, List[str]]:
204
238
_TEST_DEFAULT_METHOD_NAME ,
205
239
_TEST_CUSTOM_METHOD_NAME ,
206
240
],
241
+ _TEST_ASYNC_API_MODE : [
242
+ _TEST_DEFAULT_ASYNC_METHOD_NAME ,
243
+ _TEST_CUSTOM_ASYNC_METHOD_NAME ,
244
+ ],
207
245
_TEST_STREAM_API_MODE : [
208
246
_TEST_DEFAULT_STREAM_METHOD_NAME ,
209
247
_TEST_CUSTOM_STREAM_METHOD_NAME ,
@@ -291,9 +329,11 @@ def register_operations(self) -> Dict[str, List[str]]:
291
329
_TEST_REQUIREMENTS_FILE = _agent_engines ._REQUIREMENTS_FILE
292
330
_TEST_EXTRA_PACKAGES_FILE = _agent_engines ._EXTRA_PACKAGES_FILE
293
331
_TEST_STANDARD_API_MODE = _agent_engines ._STANDARD_API_MODE
332
+ _TEST_ASYNC_API_MODE = _agent_engines ._ASYNC_API_MODE
294
333
_TEST_STREAM_API_MODE = _agent_engines ._STREAM_API_MODE
295
334
_TEST_ASYNC_STREAM_API_MODE = _agent_engines ._ASYNC_STREAM_API_MODE
296
335
_TEST_DEFAULT_METHOD_NAME = _agent_engines ._DEFAULT_METHOD_NAME
336
+ _TEST_DEFAULT_ASYNC_METHOD_NAME = _agent_engines ._DEFAULT_ASYNC_METHOD_NAME
297
337
_TEST_DEFAULT_STREAM_METHOD_NAME = _agent_engines ._DEFAULT_STREAM_METHOD_NAME
298
338
_TEST_DEFAULT_ASYNC_STREAM_METHOD_NAME = (
299
339
_agent_engines ._DEFAULT_ASYNC_STREAM_METHOD_NAME
@@ -304,6 +344,7 @@ def register_operations(self) -> Dict[str, List[str]]:
304
344
_TEST_MODE_KEY_IN_SCHEMA = _agent_engines ._MODE_KEY_IN_SCHEMA
305
345
_TEST_METHOD_NAME_KEY_IN_SCHEMA = _agent_engines ._METHOD_NAME_KEY_IN_SCHEMA
306
346
_TEST_CUSTOM_METHOD_NAME = "custom_method"
347
+ _TEST_CUSTOM_ASYNC_METHOD_NAME = "custom_async_method"
307
348
_TEST_CUSTOM_STREAM_METHOD_NAME = "custom_stream_method"
308
349
_TEST_CUSTOM_ASYNC_STREAM_METHOD_NAME = "custom_async_stream_method"
309
350
_TEST_CUSTOM_METHOD_DEFAULT_DOCSTRING = """
@@ -320,6 +361,20 @@ def register_operations(self) -> Dict[str, List[str]]:
320
361
Returns:
321
362
dict[str, Any]: The response from serving the user request.
322
363
"""
364
+ _TEST_CUSTOM_ASYNC_METHOD_DEFAULT_DOCSTRING = """
365
+ Runs the Agent Engine to serve the user request.
366
+
367
+ This will be based on the `.custom_async_method(...)` of the python object that
368
+ was passed in when creating the Agent Engine. The method will invoke the
369
+ `async_query` API client of the python object.
370
+
371
+ Args:
372
+ **kwargs:
373
+ Optional. The arguments of the `.custom_async_method(...)` method.
374
+
375
+ Returns:
376
+ Coroutine[Any]: The response from serving the user request.
377
+ """
323
378
_TEST_CUSTOM_STREAM_METHOD_DEFAULT_DOCSTRING = """
324
379
Runs the Agent Engine to serve the user request.
325
380
@@ -429,6 +484,13 @@ def register_operations(self) -> Dict[str, List[str]]:
429
484
]
430
485
_TEST_AGENT_ENGINE_OPERATION_SCHEMAS = []
431
486
_TEST_AGENT_ENGINE_EXTRA_PACKAGE = "fake.py"
487
+ _TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA = _utils .to_proto (
488
+ _utils .generate_schema (
489
+ AsyncQueryEngine ().async_query ,
490
+ schema_name = _TEST_DEFAULT_ASYNC_METHOD_NAME ,
491
+ )
492
+ )
493
+ _TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA [_TEST_MODE_KEY_IN_SCHEMA ] = _TEST_ASYNC_API_MODE
432
494
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA = _utils .to_proto (
433
495
_utils .generate_schema (
434
496
OperationRegistrableEngine ().custom_method ,
@@ -438,6 +500,15 @@ def register_operations(self) -> Dict[str, List[str]]:
438
500
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA [
439
501
_TEST_MODE_KEY_IN_SCHEMA
440
502
] = _TEST_STANDARD_API_MODE
503
+ _TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA = _utils .to_proto (
504
+ _utils .generate_schema (
505
+ OperationRegistrableEngine ().custom_async_method ,
506
+ schema_name = _TEST_CUSTOM_ASYNC_METHOD_NAME ,
507
+ )
508
+ )
509
+ _TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA [
510
+ _TEST_MODE_KEY_IN_SCHEMA
511
+ ] = _TEST_ASYNC_API_MODE
441
512
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA = _utils .to_proto (
442
513
_utils .generate_schema (
443
514
StreamQueryEngine ().stream_query ,
@@ -475,6 +546,8 @@ def register_operations(self) -> Dict[str, List[str]]:
475
546
_TEST_OPERATION_REGISTRABLE_SCHEMAS = [
476
547
_TEST_AGENT_ENGINE_QUERY_SCHEMA ,
477
548
_TEST_AGENT_ENGINE_CUSTOM_METHOD_SCHEMA ,
549
+ _TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA ,
550
+ _TEST_AGENT_ENGINE_ASYNC_CUSTOM_METHOD_SCHEMA ,
478
551
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA ,
479
552
_TEST_AGENT_ENGINE_CUSTOM_STREAM_QUERY_SCHEMA ,
480
553
_TEST_AGENT_ENGINE_ASYNC_STREAM_QUERY_SCHEMA ,
@@ -499,6 +572,7 @@ def register_operations(self) -> Dict[str, List[str]]:
499
572
_TEST_METHOD_TO_BE_UNREGISTERED_SCHEMA [
500
573
_TEST_MODE_KEY_IN_SCHEMA
501
574
] = _TEST_STANDARD_API_MODE
575
+ _TEST_ASYNC_QUERY_SCHEMAS = [_TEST_AGENT_ENGINE_ASYNC_METHOD_SCHEMA ]
502
576
_TEST_STREAM_QUERY_SCHEMAS = [
503
577
_TEST_AGENT_ENGINE_STREAM_QUERY_SCHEMA ,
504
578
]
@@ -758,6 +832,17 @@ def query() -> str:
758
832
return "RESPONSE"
759
833
760
834
835
+ class InvalidCapitalizeEngineWithoutAsyncQuerySelf :
836
+ """A sample Agent Engine with an invalid async_query method."""
837
+
838
+ def set_up (self ):
839
+ pass
840
+
841
+ async def async_query () -> str :
842
+ """Runs the engine."""
843
+ return "RESPONSE"
844
+
845
+
761
846
class InvalidCapitalizeEngineWithoutStreamQuerySelf :
762
847
"""A sample Agent Engine with an invalid query_stream_query method."""
763
848
@@ -1161,6 +1246,23 @@ def test_get_agent_framework(
1161
1246
),
1162
1247
),
1163
1248
),
1249
+ (
1250
+ "Update the async query engine" ,
1251
+ {"agent_engine" : AsyncQueryEngine ()},
1252
+ types .reasoning_engine_service .UpdateReasoningEngineRequest (
1253
+ reasoning_engine = _generate_agent_engine_with_class_methods_and_agent_framework (
1254
+ _TEST_ASYNC_QUERY_SCHEMAS ,
1255
+ _agent_engines ._DEFAULT_AGENT_FRAMEWORK ,
1256
+ ),
1257
+ update_mask = field_mask_pb2 .FieldMask (
1258
+ paths = [
1259
+ "spec.package_spec.pickle_object_gcs_uri" ,
1260
+ "spec.class_methods" ,
1261
+ "spec.agent_framework" ,
1262
+ ]
1263
+ ),
1264
+ ),
1265
+ ),
1164
1266
(
1165
1267
"Update the stream query engine" ,
1166
1268
{"agent_engine" : StreamQueryEngine ()},
@@ -1534,6 +1636,20 @@ def test_query_agent_engine(
1534
1636
),
1535
1637
_TEST_STANDARD_API_MODE ,
1536
1638
),
1639
+ (
1640
+ _utils .generate_schema (
1641
+ OperationRegistrableEngine ().async_query ,
1642
+ schema_name = _TEST_DEFAULT_ASYNC_METHOD_NAME ,
1643
+ ),
1644
+ _TEST_ASYNC_API_MODE ,
1645
+ ),
1646
+ (
1647
+ _utils .generate_schema (
1648
+ OperationRegistrableEngine ().custom_async_method ,
1649
+ schema_name = _TEST_CUSTOM_ASYNC_METHOD_NAME ,
1650
+ ),
1651
+ _TEST_ASYNC_API_MODE ,
1652
+ ),
1537
1653
(
1538
1654
_utils .generate_schema (
1539
1655
OperationRegistrableEngine ().stream_query ,
@@ -2320,8 +2436,8 @@ def test_create_agent_engine_no_query_method(
2320
2436
TypeError ,
2321
2437
match = (
2322
2438
"agent_engine has none of the following callable methods: "
2323
- "`query`, `stream_query`, `async_stream_query` or "
2324
- "`register_operations`."
2439
+ "`query`, `async_query`, ` stream_query`, `async_stream_query` "
2440
+ "or `register_operations`."
2325
2441
),
2326
2442
):
2327
2443
agent_engines .create (
@@ -2344,8 +2460,8 @@ def test_create_agent_engine_noncallable_query_attribute(
2344
2460
TypeError ,
2345
2461
match = (
2346
2462
"agent_engine has none of the following callable methods: "
2347
- "`query`, `stream_query`, `async_stream_query` or "
2348
- "`register_operations`."
2463
+ "`query`, `async_query`, ` stream_query`, `async_stream_query` "
2464
+ "or `register_operations`."
2349
2465
),
2350
2466
):
2351
2467
agent_engines .create (
@@ -2406,6 +2522,23 @@ def test_create_agent_engine_with_invalid_query_method(
2406
2522
requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2407
2523
)
2408
2524
2525
+ def test_create_agent_engine_with_invalid_async_query_method (
2526
+ self ,
2527
+ create_agent_engine_mock ,
2528
+ cloud_storage_create_bucket_mock ,
2529
+ tarfile_open_mock ,
2530
+ cloudpickle_dump_mock ,
2531
+ cloudpickle_load_mock ,
2532
+ importlib_metadata_version_mock ,
2533
+ get_agent_engine_mock ,
2534
+ ):
2535
+ with pytest .raises (ValueError , match = "Invalid async_query signature" ):
2536
+ agent_engines .create (
2537
+ InvalidCapitalizeEngineWithoutAsyncQuerySelf (),
2538
+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
2539
+ requirements = _TEST_AGENT_ENGINE_REQUIREMENTS ,
2540
+ )
2541
+
2409
2542
def test_create_agent_engine_with_invalid_stream_query_method (
2410
2543
self ,
2411
2544
create_agent_engine_mock ,
@@ -2574,8 +2707,8 @@ def test_update_agent_engine_no_query_method(
2574
2707
TypeError ,
2575
2708
match = (
2576
2709
"agent_engine has none of the following callable methods: "
2577
- "`query`, `stream_query`, `async_stream_query` or "
2578
- "`register_operations`."
2710
+ "`query`, `async_query`, ` stream_query`, `async_stream_query` "
2711
+ "or `register_operations`."
2579
2712
),
2580
2713
):
2581
2714
test_agent_engine = _generate_agent_engine_to_update ()
@@ -2597,8 +2730,8 @@ def test_update_agent_engine_noncallable_query_attribute(
2597
2730
TypeError ,
2598
2731
match = (
2599
2732
"agent_engine has none of the following callable methods: "
2600
- "`query`, `stream_query`, `async_stream_query` or "
2601
- "`register_operations`."
2733
+ "`query`, `async_query`, ` stream_query`, `async_stream_query` "
2734
+ "or `register_operations`."
2602
2735
),
2603
2736
):
2604
2737
test_agent_engine = _generate_agent_engine_to_update ()
@@ -2737,7 +2870,7 @@ def test_update_class_methods_spec_with_registered_operation_not_found(self):
2737
2870
"register the API methods: "
2738
2871
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
2739
2872
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
2740
- "Supported modes are: ``, `stream` and `async_stream`.}"
2873
+ "Supported modes are: ``, `async`, ` stream` and `async_stream`.}"
2741
2874
),
2742
2875
),
2743
2876
],
0 commit comments