Skip to content

Commit 243462e

Browse files
rm-openaivrtnis
authored andcommitted
Update codex action to use repo syntax (#1218)
Action isn't published yet, so gotta do this
1 parent e0b1cdf commit 243462e

File tree

8 files changed

+383
-14
lines changed

8 files changed

+383
-14
lines changed

.github/workflows/codex.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
# Note it is possible that the `verify` step internal to Run Codex will
5454
# fail, in which case the work to setup the repo was worthless :(
5555
- name: Run Codex
56-
uses: openai/codex-action@latest
56+
uses: openai/codex/.github/actions/codex@main
5757
with:
5858
openai_api_key: ${{ secrets.PROD_OPENAI_API_KEY }}
5959
github_token: ${{ secrets.GITHUB_TOKEN }}

src/agents/agent.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
if TYPE_CHECKING:
2828
from .lifecycle import AgentHooks
2929
from .mcp import MCPServer
30-
from .result import RunResult
30+
from .result import RunResult, RunResultStreaming
3131

3232

3333
@dataclass
@@ -233,7 +233,9 @@ def as_tool(
233233
self,
234234
tool_name: str | None,
235235
tool_description: str | None,
236+
*,
236237
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
238+
stream_inner_events: bool = False,
237239
) -> Tool:
238240
"""Transform this agent into a tool, callable by other agents.
239241
@@ -258,17 +260,36 @@ def as_tool(
258260
async def run_agent(context: RunContextWrapper, input: str) -> str:
259261
from .run import Runner
260262

261-
output = await Runner.run(
262-
starting_agent=self,
263-
input=input,
264-
context=context.context,
265-
)
263+
output_run: RunResult | RunResultStreaming
264+
if stream_inner_events:
265+
from .stream_events import RunItemStreamEvent
266+
267+
sub_run = Runner.run_streamed(
268+
self,
269+
input=input,
270+
context=context.context,
271+
)
272+
parent_queue = getattr(context, "_event_queue", None)
273+
async for ev in sub_run.stream_events():
274+
if parent_queue is not None and isinstance(ev, RunItemStreamEvent):
275+
if ev.name in ("tool_called", "tool_output"):
276+
parent_queue.put_nowait(ev)
277+
output_run = sub_run
278+
else:
279+
output_run = await Runner.run(
280+
starting_agent=self,
281+
input=input,
282+
context=context.context,
283+
)
284+
266285
if custom_output_extractor:
267-
return await custom_output_extractor(output)
286+
return await custom_output_extractor(cast(Any, output_run))
268287

269-
return ItemHelpers.text_message_outputs(output.new_items)
288+
return ItemHelpers.text_message_outputs(output_run.new_items)
270289

271-
return run_agent
290+
tool = run_agent
291+
tool.stream_inner_events = stream_inner_events
292+
return tool
272293

273294
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
274295
"""Get the system prompt for the agent."""

src/agents/items.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def tool_call_output_item(
299299
) -> FunctionCallOutput:
300300
"""Creates a tool call output item from a tool call and its output."""
301301
return {
302-
"call_id": tool_call.call_id,
302+
"call_id": str(tool_call.call_id),
303303
"output": output,
304304
"type": "function_call_output",
305305
}

src/agents/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def run_streamed(
576576
trace=new_trace,
577577
context_wrapper=context_wrapper,
578578
)
579+
context_wrapper._event_queue = streamed_result._event_queue
579580

580581
# Kick off the actual agent loop in the background and return the streamed result object.
581582
streamed_result._run_impl_task = asyncio.create_task(

src/agents/run_context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import asyncio
12
from dataclasses import dataclass, field
2-
from typing import Any, Generic
3+
from typing import Any, Generic, Optional
34

45
from typing_extensions import TypeVar
56

@@ -24,3 +25,5 @@ class RunContextWrapper(Generic[TContext]):
2425
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
2526
last chunk of the stream is processed.
2627
"""
28+
29+
_event_queue: Optional[asyncio.Queue[Any]] = field(default=None, init=False, repr=False)

src/agents/tool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def __post_init__(self):
9797
if self.strict_json_schema:
9898
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
9999

100+
stream_inner_events: bool = False
101+
"""Whether to stream inner events when used as an agent tool."""
102+
100103

101104
@dataclass
102105
class FileSearchTool:

src/agents/tool_context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from dataclasses import dataclass, field, fields
23
from typing import Any, Optional
34

@@ -24,11 +25,13 @@ class ToolContext(RunContextWrapper[TContext]):
2425
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
2526
"""The ID of the tool call."""
2627

28+
_event_queue: asyncio.Queue[Any] | None = field(default=None, init=False, repr=False)
29+
2730
@classmethod
2831
def from_agent_context(
2932
cls,
3033
context: RunContextWrapper[TContext],
31-
tool_call_id: str,
34+
tool_call_id: str | int,
3235
tool_call: Optional[ResponseFunctionToolCall] = None,
3336
) -> "ToolContext":
3437
"""
@@ -39,4 +42,7 @@ def from_agent_context(
3942
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
4043
}
4144
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42-
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
45+
obj = cls(tool_name=tool_name, tool_call_id=str(tool_call_id), **base_values)
46+
if hasattr(context, "event_queue"):
47+
obj.event_queue = context.event_queue
48+
return obj

0 commit comments

Comments
 (0)