
Commit 64283c6

Start and finish streaming trace in impl method
1 parent 5639606 commit 64283c6

File tree

9 files changed, +131 -9 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -141,4 +141,5 @@ cython_debug/
 .ruff_cache/

 # PyPI configuration file
-.pypirc
+.pypirc
+.aider*
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import asyncio
from dataclasses import dataclass

from agents import Agent, Runner

"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.

In this example, we define an output type that is not strict-compatible, and then we run the
agent with strict_json_schema=False.

To understand which schemas are strict-compatible, see:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
"""


@dataclass
class OutputType:
    jokes: dict[int, str]
    """A list of jokes, indexed by joke number."""


async def main():
    agent = Agent(
        name="Assistant",
        instructions="You are a helpful assistant.",
        output_type=OutputType,
    )

    input = "Tell me 3 short jokes."

    # First, let's try with a strict output type. This should raise an exception.
    try:
        result = await Runner.run(agent, input)
        raise AssertionError("Should have raised an exception")
    except Exception as e:
        print(f"Error: {e}")

    # Now let's try again with a non-strict output type. This should work.
    agent.output_schema_strict = False
    result = await Runner.run(agent, input)
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ dev = [
     "graphviz",
     "mkdocs-static-i18n>=1.3.0",
     "eval-type-backport>=0.2.2",
+    "fastapi >= 0.110.0, <1",
 ]

 [tool.uv.workspace]

src/agents/result.py

Lines changed: 0 additions & 3 deletions
@@ -185,9 +185,6 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
             yield item
             self._event_queue.task_done()

-        if self._trace:
-            self._trace.finish(reset_current=True)
-
         self._cleanup_tasks()

         if self._stored_exception:

src/agents/run.py

Lines changed: 5 additions & 4 deletions
@@ -404,10 +404,6 @@ def run_streamed(
                 disabled=run_config.tracing_disabled,
             )
         )
-        # Need to start the trace here, because the current trace contextvar is captured at
-        # asyncio.create_task time
-        if new_trace:
-            new_trace.start(mark_as_current=True)

         output_schema = cls._get_output_schema(starting_agent)
         context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
@@ -499,6 +495,9 @@ async def _run_streamed_impl(
         run_config: RunConfig,
         previous_response_id: str | None,
     ):
+        if streamed_result._trace:
+            streamed_result._trace.start(mark_as_current=True)
+
         current_span: Span[AgentSpanData] | None = None
         current_agent = starting_agent
         current_turn = 0
@@ -625,6 +624,8 @@ async def _run_streamed_impl(
         finally:
             if current_span:
                 current_span.finish(reset_current=True)
+            if streamed_result._trace:
+                streamed_result._trace.finish(reset_current=True)

     @classmethod
     async def _run_single_turn_streamed(
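
The comment removed from run_streamed above points at the underlying constraint: asyncio copies the current context into a task at create_task time, and a ContextVar token can only be reset from the context that created it. The stdlib-only sketch below is not part of this commit (current_trace is a stand-in for the SDK's tracing contextvar); it reproduces that failure mode and shows why setting and resetting inside the same coroutine, as _run_streamed_impl now does, is safe.

import asyncio
import contextvars

# Stand-in for the tracing contextvar; the name is illustrative only.
current_trace: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_trace", default=None
)


async def finish_in_another_context(token: contextvars.Token) -> None:
    # Each task runs in a copy of the creating context, so a token created
    # elsewhere cannot be reset here.
    try:
        current_trace.reset(token)
    except ValueError as exc:
        print(f"finish failed: {exc}")  # "... was created in a different Context"


async def start_and_finish_together() -> None:
    # What the commit does instead: set and reset within the same coroutine.
    token = current_trace.set("trace-xyz")
    try:
        await asyncio.sleep(0)  # streaming work happens here
    finally:
        current_trace.reset(token)


async def main() -> None:
    token = current_trace.set("trace-abc")  # "start" in one context...
    await asyncio.create_task(finish_in_another_context(token))  # ...finish in another: fails
    current_trace.reset(token)  # resetting in the context that set it: fine
    await asyncio.create_task(start_and_finish_together())


asyncio.run(main())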

tests/fastapi/__init__.py

Whitespace-only changes.

tests/fastapi/streaming_app.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
from collections.abc import AsyncIterator

from fastapi import FastAPI
from starlette.responses import StreamingResponse

from agents import Agent, Runner, RunResultStreaming

agent = Agent(
    name="Assistant",
    instructions="You are a helpful assistant.",
)


app = FastAPI()


@app.post("/stream")
async def stream():
    result = Runner.run_streamed(agent, input="Tell me a joke")
    stream_handler = StreamHandler(result)
    return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson")


class StreamHandler:
    def __init__(self, result: RunResultStreaming):
        self.result = result

    async def stream_events(self) -> AsyncIterator[str]:
        async for event in self.result.stream_events():
            yield f"{event.type}\n\n"
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import pytest
from httpx import ASGITransport, AsyncClient
from inline_snapshot import snapshot

from ..fake_model import FakeModel
from ..test_responses import get_text_message
from .streaming_app import agent, app


@pytest.mark.asyncio
async def test_streaming_context():
    """This ensures that FastAPI streaming works. The context for this test is that the Runner
    method was called in one async context, and the streaming was ended in another context,
    leading to a tracing error because the trace was closed in the wrong context. This test
    ensures that this now works.
    """
    model = FakeModel()
    agent.model = model
    model.set_next_output([get_text_message("done")])

    transport = ASGITransport(app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        async with ac.stream("POST", "/stream") as r:
            assert r.status_code == 200
            body = (await r.aread()).decode("utf-8")
            lines = [line for line in body.splitlines() if line]
            assert lines == snapshot(
                ["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
            )

uv.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default.
