Skip to content

Commit 1a3d953

Browse files
rm-openaivrtnis
authored andcommitted
Update codex action to use repo syntax (openai#1218)
Action isn't published yet, so gotta do this
1 parent 594efb0 commit 1a3d953

File tree

7 files changed

+385
-13
lines changed

7 files changed

+385
-13
lines changed

src/agents/agent.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
if TYPE_CHECKING:
2828
from .lifecycle import AgentHooks
2929
from .mcp import MCPServer
30-
from .result import RunResult
30+
from .result import RunResult, RunResultStreaming
3131

3232

3333
@dataclass
@@ -356,9 +356,14 @@ def as_tool(
356356
self,
357357
tool_name: str | None,
358358
tool_description: str | None,
359+
*,
359360
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
361+
<<<<<<< HEAD
360362
is_enabled: bool
361363
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
364+
=======
365+
stream_inner_events: bool = False,
366+
>>>>>>> 243462e (Update codex action to use repo syntax (#1218))
362367
) -> Tool:
363368
"""Transform this agent into a tool, callable by other agents.
364369
@@ -387,17 +392,36 @@ def as_tool(
387392
async def run_agent(context: RunContextWrapper, input: str) -> str:
388393
from .run import Runner
389394

390-
output = await Runner.run(
391-
starting_agent=self,
392-
input=input,
393-
context=context.context,
394-
)
395+
output_run: RunResult | RunResultStreaming
396+
if stream_inner_events:
397+
from .stream_events import RunItemStreamEvent
398+
399+
sub_run = Runner.run_streamed(
400+
self,
401+
input=input,
402+
context=context.context,
403+
)
404+
parent_queue = getattr(context, "_event_queue", None)
405+
async for ev in sub_run.stream_events():
406+
if parent_queue is not None and isinstance(ev, RunItemStreamEvent):
407+
if ev.name in ("tool_called", "tool_output"):
408+
parent_queue.put_nowait(ev)
409+
output_run = sub_run
410+
else:
411+
output_run = await Runner.run(
412+
starting_agent=self,
413+
input=input,
414+
context=context.context,
415+
)
416+
395417
if custom_output_extractor:
396-
return await custom_output_extractor(output)
418+
return await custom_output_extractor(cast(Any, output_run))
397419

398-
return ItemHelpers.text_message_outputs(output.new_items)
420+
return ItemHelpers.text_message_outputs(output_run.new_items)
399421

400-
return run_agent
422+
tool = run_agent
423+
tool.stream_inner_events = stream_inner_events
424+
return tool
401425

402426
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
403427
if isinstance(self.instructions, str):

src/agents/items.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def tool_call_output_item(
302302
) -> FunctionCallOutput:
303303
"""Creates a tool call output item from a tool call and its output."""
304304
return {
305-
"call_id": tool_call.call_id,
305+
"call_id": str(tool_call.call_id),
306306
"output": output,
307307
"type": "function_call_output",
308308
}

src/agents/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ def run_streamed(
607607
trace=new_trace,
608608
context_wrapper=context_wrapper,
609609
)
610+
context_wrapper._event_queue = streamed_result._event_queue
610611

611612
# Kick off the actual agent loop in the background and return the streamed result object.
612613
streamed_result._run_impl_task = asyncio.create_task(

src/agents/run_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import asyncio
12
from dataclasses import dataclass, field
2-
from typing import Any, Generic
3+
from typing import Any, Generic, Optional
34

45
from typing_extensions import TypeVar
56

@@ -24,3 +25,5 @@ class RunContextWrapper(Generic[TContext]):
2425
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
2526
last chunk of the stream is processed.
2627
"""
28+
29+
_event_queue: Optional[asyncio.Queue[Any]] = field(default=None, init=False, repr=False)

src/agents/tool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __post_init__(self):
9797
if self.strict_json_schema:
9898
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
9999

100+
stream_inner_events: bool = False
101+
"""Whether to stream inner events when used as an agent tool."""
102+
100103

101104
@dataclass
102105
class FileSearchTool:

src/agents/tool_context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass, field, fields
23
from typing import Any, Optional
34

@@ -24,11 +25,13 @@ class ToolContext(RunContextWrapper[TContext]):
2425
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
2526
"""The ID of the tool call."""
2627

28+
_event_queue: asyncio.Queue[Any] | None = field(default=None, init=False, repr=False)
29+
2730
@classmethod
2831
def from_agent_context(
2932
cls,
3033
context: RunContextWrapper[TContext],
31-
tool_call_id: str,
34+
tool_call_id: str | int,
3235
tool_call: Optional[ResponseFunctionToolCall] = None,
3336
) -> "ToolContext":
3437
"""
@@ -39,4 +42,7 @@ def from_agent_context(
3942
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
4043
}
4144
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42-
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
45+
obj = cls(tool_name=tool_name, tool_call_id=str(tool_call_id), **base_values)
46+
if hasattr(context, "event_queue"):
47+
obj.event_queue = context.event_queue
48+
return obj

0 commit comments

Comments
 (0)