|
27 | 27 | if TYPE_CHECKING: |
28 | 28 | from .lifecycle import AgentHooks |
29 | 29 | from .mcp import MCPServer |
30 | | - from .result import RunResult |
| 30 | + from .result import RunResult, RunResultStreaming |
31 | 31 |
|
32 | 32 |
|
33 | 33 | @dataclass |
@@ -356,9 +356,14 @@ def as_tool( |
356 | 356 | self, |
357 | 357 | tool_name: str | None, |
358 | 358 | tool_description: str | None, |
| 359 | + *, |
359 | 360 | custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, |
| 361 | +<<<<<<< HEAD |
360 | 362 | is_enabled: bool |
361 | 363 | | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, |
| 364 | +======= |
| 365 | + stream_inner_events: bool = False, |
| 366 | +>>>>>>> 243462e (Update codex action to use repo syntax (#1218)) |
362 | 367 | ) -> Tool: |
363 | 368 | """Transform this agent into a tool, callable by other agents. |
364 | 369 |
|
@@ -387,17 +392,36 @@ def as_tool( |
387 | 392 | async def run_agent(context: RunContextWrapper, input: str) -> str: |
388 | 393 | from .run import Runner |
389 | 394 |
|
390 | | - output = await Runner.run( |
391 | | - starting_agent=self, |
392 | | - input=input, |
393 | | - context=context.context, |
394 | | - ) |
| 395 | + output_run: RunResult | RunResultStreaming |
| 396 | + if stream_inner_events: |
| 397 | + from .stream_events import RunItemStreamEvent |
| 398 | + |
| 399 | + sub_run = Runner.run_streamed( |
| 400 | + self, |
| 401 | + input=input, |
| 402 | + context=context.context, |
| 403 | + ) |
| 404 | + parent_queue = getattr(context, "_event_queue", None) |
| 405 | + async for ev in sub_run.stream_events(): |
| 406 | + if parent_queue is not None and isinstance(ev, RunItemStreamEvent): |
| 407 | + if ev.name in ("tool_called", "tool_output"): |
| 408 | + parent_queue.put_nowait(ev) |
| 409 | + output_run = sub_run |
| 410 | + else: |
| 411 | + output_run = await Runner.run( |
| 412 | + starting_agent=self, |
| 413 | + input=input, |
| 414 | + context=context.context, |
| 415 | + ) |
| 416 | + |
395 | 417 | if custom_output_extractor: |
396 | | - return await custom_output_extractor(output) |
| 418 | + return await custom_output_extractor(cast(Any, output_run)) |
397 | 419 |
|
398 | | - return ItemHelpers.text_message_outputs(output.new_items) |
| 420 | + return ItemHelpers.text_message_outputs(output_run.new_items) |
399 | 421 |
|
400 | | - return run_agent |
| 422 | + tool = run_agent |
| 423 | + tool.stream_inner_events = stream_inner_events |
| 424 | + return tool |
401 | 425 |
|
402 | 426 | async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: |
403 | 427 | if isinstance(self.instructions, str): |
|
0 commit comments