Add Streaming of Function Call Arguments to Chat Completions #999
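This PR changes the Chat Completions stream handler so that function call arguments are emitted via `response.function_call_arguments.delta` events as they arrive, instead of being buffered and flushed only when the stream ends. The `response.output_item.added` event for a tool call is now sent as soon as its name and call_id are known. A sketch of the resulting event-type order for a single tool call whose arguments arrive in two chunks (event type strings are taken from the diff below; the chunk contents are illustrative):

# Illustrative only; not code from this PR.
before = [  # old behavior: everything flushed at the end of the stream
    "response.output_item.added",              # arguments already complete
    "response.function_call_arguments.delta",  # delta="arg1arg2"
    "response.output_item.done",
]
after = [  # new behavior: arguments streamed as they arrive
    "response.output_item.added",              # arguments=""
    "response.function_call_arguments.delta",  # delta="arg1"
    "response.function_call_arguments.delta",  # delta="arg2"
    "response.output_item.done",               # item.arguments == "arg1arg2"
]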

Open · wants to merge 12 commits into main
Changes from all commits
177 changes: 134 additions & 43 deletions src/agents/models/chatcmpl_stream_handler.py
@@ -53,6 +53,9 @@ class StreamingState:
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
# Fields for real-time function call streaming
function_call_streaming: dict[int, bool] = field(default_factory=dict)
function_call_output_idx: dict[int, int] = field(default_factory=dict)


class SequenceNumber:
@@ -255,9 +258,7 @@ async def handle_stream(
# Accumulate the refusal string in the output part
state.refusal_content_index_and_output[1].refusal += delta.refusal

# Handle tool calls
# Because we don't know the name of the function until the end of the stream, we'll
# save everything and yield events at the end
# Handle tool calls with real-time streaming support
if delta.tool_calls:
for tc_delta in delta.tool_calls:
if tc_delta.index not in state.function_calls:
@@ -268,15 +269,76 @@
type="function_call",
call_id="",
)
state.function_call_streaming[tc_delta.index] = False

tc_function = tc_delta.function

# Accumulate arguments as they come in
state.function_calls[tc_delta.index].arguments += (
tc_function.arguments if tc_function else ""
) or ""
state.function_calls[tc_delta.index].name += (
tc_function.name if tc_function else ""
) or ""
state.function_calls[tc_delta.index].call_id = tc_delta.id or ""

# Set function name directly (it's correct from the first function call chunk)
if tc_function and tc_function.name:
state.function_calls[tc_delta.index].name = tc_function.name

if tc_delta.id:
state.function_calls[tc_delta.index].call_id = tc_delta.id

function_call = state.function_calls[tc_delta.index]

# Start streaming as soon as we have function name and call_id
if (not state.function_call_streaming[tc_delta.index] and
function_call.name and
function_call.call_id):

# Calculate the output index for this function call
function_call_starting_index = 0
if state.reasoning_content_index_and_output:
function_call_starting_index += 1
if state.text_content_index_and_output:
function_call_starting_index += 1
if state.refusal_content_index_and_output:
function_call_starting_index += 1

# Add offset for already started function calls
function_call_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Mark this function call as streaming and store its output index
state.function_call_streaming[tc_delta.index] = True
state.function_call_output_idx[
tc_delta.index
] = function_call_starting_index

# Send initial function call added event
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments="", # Start with empty arguments
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)

# Stream arguments if we've started streaming this function call
if (state.function_call_streaming.get(tc_delta.index, False) and
tc_function and
tc_function.arguments):

output_index = state.function_call_output_idx[tc_delta.index]
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=tc_function.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=output_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)

if state.reasoning_content_index_and_output:
yield ResponseReasoningSummaryPartDoneEvent(
@@ -327,42 +389,71 @@ async def handle_stream(
sequence_number=sequence_number.get_and_increment(),
)

# Actually send events for the function calls
for function_call in state.function_calls.values():
# First, a ResponseOutputItemAdded for the function call
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
# Then, yield the args
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=function_call_starting_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
# Finally, the ResponseOutputItemDone
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=function_call_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
# Send completion events for function calls
for index, function_call in state.function_calls.items():
if state.function_call_streaming.get(index, False):
# Function call was streamed, just send the completion event
output_index = state.function_call_output_idx[index]
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=output_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
else:
# Function call was not streamed (fallback to old behavior)
# This handles edge cases where the function name never arrived
fallback_starting_index = 0
if state.reasoning_content_index_and_output:
fallback_starting_index += 1
if state.text_content_index_and_output:
fallback_starting_index += 1
if state.refusal_content_index_and_output:
fallback_starting_index += 1

# Add offset for already started function calls
fallback_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Send all events at once (backward compatibility)
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=fallback_starting_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(
id=FAKE_RESPONSES_ID,
call_id=function_call.call_id,
arguments=function_call.arguments,
name=function_call.name,
type="function_call",
),
output_index=fallback_starting_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)

# Finally, send the Response completed event
outputs: list[ResponseOutputItem] = []
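For context, a minimal downstream consumer sketch (not part of this diff): it accumulates the streamed argument fragments per output item. The stream_response(...) call mirrors the tests below; the import paths and the surrounding scaffolding are assumptions and may differ from the actual package layout.

# Sketch only: consume real-time function call argument streaming.
# Assumed imports; adjust to the actual package layout.
from agents import ModelSettings
from agents.models.interface import ModelTracing

async def consume_tool_call_stream(model) -> None:
    partial_args: dict[int, str] = {}  # output_index -> accumulated arguments
    async for event in model.stream_response(
        system_instructions=None,
        input="",
        model_settings=ModelSettings(),
        tools=[],
        output_schema=None,
        handoffs=[],
        tracing=ModelTracing.DISABLED,
        previous_response_id=None,
        prompt=None,
    ):
        if event.type == "response.output_item.added":
            if getattr(event.item, "type", "") == "function_call":
                # Fired as soon as the function name and call_id are known.
                print("tool call started:", event.item.name)
        elif event.type == "response.function_call_arguments.delta":
            # Argument fragments now arrive incrementally.
            partial_args[event.output_index] = (
                partial_args.get(event.output_index, "") + event.delta
            )
        elif event.type == "response.output_item.done":
            if getattr(event.item, "type", "") == "function_call":
                print("final arguments:", event.item.arguments)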
148 changes: 131 additions & 17 deletions tests/models/test_litellm_chatcompletions_stream.py
@@ -214,17 +214,18 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
the model is streaming a function/tool call instead of plain text.
The function call will be split across two chunks.
"""
# Simulate a single tool call whose ID stays constant and function name/args built over chunks.
# Simulate a single tool call with complete function name in first chunk
# and arguments split across chunks (reflecting real API behavior)
tool_call_delta1 = ChoiceDeltaToolCall(
index=0,
id="tool-id",
function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"),
function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"),
type="function",
)
tool_call_delta2 = ChoiceDeltaToolCall(
index=0,
id="tool-id",
function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"),
function=ChoiceDeltaToolCallFunction(name=None, arguments="arg2"),
type="function",
)
chunk1 = ChatCompletionChunk(
@@ -284,18 +285,131 @@ async def patched_fetch_response(self, *args, **kwargs):
# The added item should be a ResponseFunctionToolCall.
added_fn = output_events[1].item
assert isinstance(added_fn, ResponseFunctionToolCall)
assert added_fn.name == "my_func"  # Name should be concatenation of both chunks.
assert added_fn.arguments == "arg1arg2"
assert output_events[2].type == "response.function_call_arguments.delta"
assert output_events[2].delta == "arg1arg2"
assert output_events[3].type == "response.output_item.done"
assert output_events[4].type == "response.completed"
assert added_fn.name == "my_func"  # Name should be complete from first chunk
assert added_fn.arguments == ""  # Arguments start empty
assert output_events[2].type == "response.function_call_arguments.delta"
assert output_events[2].delta == "arg1"  # First argument chunk
assert output_events[3].type == "response.function_call_arguments.delta"
assert output_events[3].delta == "arg2"  # Second argument chunk
assert output_events[4].type == "response.output_item.done"
assert output_events[5].type == "response.completed"
# Final function call should have complete arguments
final_fn = output_events[4].item
assert isinstance(final_fn, ResponseFunctionToolCall)
assert final_fn.name == "my_func"
assert final_fn.arguments == "arg1arg2"


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
"""
Validate that LiteLLM `stream_response` also emits function call arguments in real-time
as they are received, ensuring consistent behavior across model providers.
"""
# Simulate realistic chunks: name first, then arguments incrementally
tool_call_delta1 = ChoiceDeltaToolCall(
index=0,
id="litellm-call-456",
function=ChoiceDeltaToolCallFunction(name="generate_code", arguments=""),
type="function",
)
tool_call_delta2 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='{"language": "'),
type="function",
)
tool_call_delta3 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='python", "task": "'),
type="function",
)
tool_call_delta4 = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(arguments='hello world"}'),
type="function",
)

chunk1 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
)
chunk2 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
)
chunk3 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta3]))],
)
chunk4 = ChatCompletionChunk(
id="chunk-id",
created=1,
model="fake",
object="chat.completion.chunk",
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta4]))],
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
)

async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
for c in (chunk1, chunk2, chunk3, chunk4):
yield c

async def patched_fetch_response(self, *args, **kwargs):
resp = Response(
id="resp-id",
created_at=0,
model="fake-model",
object="response",
output=[],
tool_choice="none",
tools=[],
parallel_tool_calls=False,
)
return resp, fake_stream()

monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
model = LitellmProvider().get_model("gpt-4")
output_events = []
async for event in model.stream_response(
system_instructions=None,
input="",
model_settings=ModelSettings(),
tools=[],
output_schema=None,
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
prompt=None,
):
output_events.append(event)

# Extract events by type
function_args_delta_events = [
e for e in output_events if e.type == "response.function_call_arguments.delta"
]
output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"]

# Verify we got real-time streaming (3 argument delta events)
assert len(function_args_delta_events) == 3
assert len(output_item_added_events) == 1

# Verify the deltas were streamed correctly
expected_deltas = ['{"language": "', 'python", "task": "', 'hello world"}']
for i, delta_event in enumerate(function_args_delta_events):
assert delta_event.delta == expected_deltas[i]

# Verify function call metadata
added_event = output_item_added_events[0]
assert isinstance(added_event.item, ResponseFunctionToolCall)
assert added_event.item.name == "generate_code"
assert added_event.item.call_id == "litellm-call-456"
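For reference, the output-index bookkeeping that both the streaming path and the fallback path in the handler perform can be restated as a standalone sketch (assuming the StreamingState fields shown in the first file of this diff):

def next_function_call_output_index(state) -> int:
    """Sketch restating the handler's logic: a new function call item is
    placed after any reasoning, text, and refusal items, and after every
    function call that has already started streaming."""
    index = 0
    if state.reasoning_content_index_and_output:
        index += 1
    if state.text_content_index_and_output:
        index += 1
    if state.refusal_content_index_and_output:
        index += 1
    # Each already-streaming function call occupies one earlier output slot.
    index += sum(1 for streaming in state.function_call_streaming.values() if streaming)
    return index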