15
15
"""Utility functions for agent engines."""
16
16
17
17
import abc
18
+ import asyncio
18
19
from importlib import metadata as importlib_metadata
19
20
import inspect
20
21
import io
108
109
_AGENT_FRAMEWORK_ATTR = "agent_framework"
109
110
_ASYNC_API_MODE = "async"
110
111
_ASYNC_STREAM_API_MODE = "async_stream"
112
+ _BIDI_STREAM_API_MODE = "bidi_stream"
111
113
_BASE_MODULES = set (_BUILTIN_MODULE_NAMES + tuple (_STDLIB_MODULE_NAMES ))
112
114
_BLOB_FILENAME = "agent_engine.pkl"
113
115
_DEFAULT_AGENT_FRAMEWORK = "custom"
132
134
_DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]"
133
135
_DEFAULT_REQUIRED_PACKAGES = frozenset (["cloudpickle" , "pydantic" ])
134
136
_DEFAULT_STREAM_METHOD_NAME = "stream_query"
137
+ _DEFAULT_BIDI_STREAM_METHOD_NAME = "bidi_stream_query"
135
138
_EXTRA_PACKAGES_FILE = "dependencies.tar.gz"
136
139
_FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE = (
137
140
"Failed to register API methods. Please follow the guide to "
@@ -202,6 +205,15 @@ def stream_query(self, **kwargs) -> Iterator[Any]: # type: ignore[no-untyped-de
202
205
"""Stream responses to serve the user query."""
203
206
204
207
208
+ @typing .runtime_checkable
209
+ class BidiStreamQueryable (Protocol ):
210
+ """Protocol for Agent Engines that can stream requests and responses."""
211
+
212
+ @abc .abstractmethod
213
+ async def bidi_stream_query (self , input_queue : asyncio .Queue ) -> AsyncIterator [Any ]:
214
+ """Stream requests and responses to serve the user queries."""
215
+
216
+
205
217
@typing .runtime_checkable
206
218
class Cloneable (Protocol ):
207
219
"""Protocol for Agent Engines that can be cloned."""
@@ -234,6 +246,7 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]:
234
246
OperationRegistrable ,
235
247
Queryable ,
236
248
StreamQueryable ,
249
+ BidiStreamQueryable ,
237
250
]
238
251
239
252
@@ -557,6 +570,9 @@ def _generate_schema(
557
570
inspect .Parameter .KEYWORD_ONLY ,
558
571
inspect .Parameter .POSITIONAL_ONLY ,
559
572
)
573
+ # For a bidi endpoint, it requires an asyncio.Queue as the input, but
574
+ # it is not JSON serializable. We hence exclude it from the schema.
575
+ and param .annotation != asyncio .Queue
560
576
}
561
577
parameters = pydantic .create_model (f .__name__ , ** fields_dict ).schema ()
562
578
# Postprocessing
@@ -656,6 +672,8 @@ def _get_registered_operations(
656
672
operations [_STREAM_API_MODE ] = [_DEFAULT_STREAM_METHOD_NAME ]
657
673
if isinstance (agent , AsyncStreamQueryable ):
658
674
operations [_ASYNC_STREAM_API_MODE ] = [_DEFAULT_ASYNC_STREAM_METHOD_NAME ]
675
+ if isinstance (agent , BidiStreamQueryable ):
676
+ operations [_BIDI_STREAM_API_MODE ] = [_DEFAULT_BIDI_STREAM_METHOD_NAME ]
659
677
return operations
660
678
661
679
@@ -839,6 +857,10 @@ def _register_api_methods_or_raise(
839
857
f" contain an `{ _MODE_KEY_IN_SCHEMA } ` field."
840
858
)
841
859
api_mode = operation_schema .get (_MODE_KEY_IN_SCHEMA )
860
+ # For bidi stream api mode, we don't need to wrap the operation.
861
+ if api_mode == _BIDI_STREAM_API_MODE :
862
+ continue
863
+
842
864
if _METHOD_NAME_KEY_IN_SCHEMA not in operation_schema :
843
865
raise ValueError (
844
866
f"Operation schema { operation_schema } does not"
@@ -1212,6 +1234,7 @@ def _validate_agent_or_raise(
1212
1234
* a callable method named `query`
1213
1235
* a callable method named `stream_query`
1214
1236
* a callable method named `async_stream_query`
1237
+ * a callable method named `bidi_stream_query`
1215
1238
* a callable method named `register_operations`
1216
1239
1217
1240
Args:
@@ -1246,6 +1269,9 @@ def _validate_agent_or_raise(
1246
1269
is_async_stream_queryable = isinstance (agent , AsyncStreamQueryable ) and callable (
1247
1270
agent .async_stream_query
1248
1271
)
1272
+ is_bidi_stream_queryable = isinstance (agent , BidiStreamQueryable ) and callable (
1273
+ agent .bidi_stream_query
1274
+ )
1249
1275
is_operation_registrable = isinstance (agent , OperationRegistrable ) and callable (
1250
1276
agent .register_operations
1251
1277
)
@@ -1255,12 +1281,13 @@ def _validate_agent_or_raise(
1255
1281
or is_async_queryable
1256
1282
or is_stream_queryable
1257
1283
or is_operation_registrable
1284
+ or is_bidi_stream_queryable
1258
1285
or is_async_stream_queryable
1259
1286
):
1260
1287
raise TypeError (
1261
1288
"agent_engine has none of the following callable methods: "
1262
- "`query`, `async_query`, `stream_query`, `async_stream_query` or "
1263
- "`register_operations`."
1289
+ "`query`, `async_query`, `stream_query`, `async_stream_query`, "
1290
+ "`bidi_stream_query`, or ` register_operations`."
1264
1291
)
1265
1292
1266
1293
if is_queryable :
@@ -1299,6 +1326,15 @@ def _validate_agent_or_raise(
1299
1326
" missing `self` argument in the agent.async_stream_query method."
1300
1327
) from err
1301
1328
1329
+ if is_bidi_stream_queryable :
1330
+ try :
1331
+ inspect .signature (getattr (agent , "bidi_stream_query" ))
1332
+ except ValueError as err :
1333
+ raise ValueError (
1334
+ "Invalid bidi_stream_query signature. This might be due to a "
1335
+ " missing `self` argument in the agent.bidi_stream_query method."
1336
+ ) from err
1337
+
1302
1338
if is_operation_registrable :
1303
1339
try :
1304
1340
inspect .signature (getattr (agent , "register_operations" ))
0 commit comments