Initialize the delta tool call fields explicitly #17340

Merged (7 commits) on May 12, 2025
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/tool_parsers/utils.py
@@ -32,7 +32,7 @@ def append_delta(self, delta: DeltaMessage):
         assert len(delta.tool_calls) < 2, (
             "Streaming should include only one tool call per update.")
         for call_delta in delta.tool_calls:
-            assert call_delta.type == "function", (
+            assert call_delta.type is None or call_delta.type == "function", (
                 "Streaming tool calls should only emit function calls. Got "
                 f"{call_delta.type}")
             current_tool_call = self.tool_calls[
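
The helper's assertion is relaxed because continuation chunks now omit `type` entirely (see the `DeltaToolCall` change in `protocol.py` below). A quick sketch showing that such a chunk passes (values illustrative):

    from vllm.entrypoints.openai.protocol import DeltaFunctionCall, DeltaToolCall

    # A continuation chunk: no id, no type -- only an argument fragment.
    call_delta = DeltaToolCall(index=0,
                               function=DeltaFunctionCall(arguments='{"city": '))
    # Mirrors the relaxed check in append_delta above.
    assert call_delta.type is None or call_delta.type == "function"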
4 changes: 4 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -40,6 +40,7 @@
 from vllm.multimodal.utils import MediaConnector
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
+from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -1258,3 +1259,6 @@ def apply_mistral_chat_template(
             "An error occurred in `mistral_common` while applying chat "
             "template")
         raise ValueError from e
+
+def random_tool_call_id() -> str:
+    return f"chatcmpl-tool-{random_uuid()}"
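
The new helper centralizes the `chatcmpl-tool-` id format that every tool parser previously inlined via `random_uuid()`. A quick illustrative check (the uuid suffix differs on every call):

    from vllm.entrypoints.chat_utils import random_tool_call_id

    tool_call_id = random_tool_call_id()
    # e.g. "chatcmpl-tool-4f6e2a..." -- fixed prefix, fresh uuid suffix
    assert tool_call_id.startswith("chatcmpl-tool-")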
9 changes: 5 additions & 4 deletions vllm/entrypoints/openai/protocol.py
@@ -15,7 +15,8 @@
 from typing_extensions import TypeAlias

 from vllm import envs
-from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         random_tool_call_id)
 from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
@@ -1314,7 +1315,7 @@ class FunctionCall(OpenAIBaseModel):


 class ToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    id: str = Field(default_factory=random_tool_call_id)
     type: Literal["function"] = "function"
     function: FunctionCall

@@ -1326,8 +1327,8 @@ class DeltaFunctionCall(BaseModel):


 # a tool call delta where everything is optional
 class DeltaToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
-    type: Literal["function"] = "function"
+    id: Optional[str] = None
+    type: Optional[Literal["function"]] = None
     index: int
     function: Optional[DeltaFunctionCall] = None
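
With `id` and `type` now optional and `None` by default, only the first delta of a streamed tool call needs to carry them; continuation deltas hold just the `index` and the next argument fragment. A sketch of the two shapes (field values are illustrative):

    from vllm.entrypoints.openai.protocol import DeltaFunctionCall, DeltaToolCall

    # First chunk of a streamed tool call: id and type are sent once.
    first = DeltaToolCall(
        id="chatcmpl-tool-abc123",  # illustrative; real ids come from random_tool_call_id()
        type="function",
        index=0,
        function=DeltaFunctionCall(name="get_weather", arguments=""),
    )

    # Continuation chunk: only the index and the next argument fragment.
    cont = DeltaToolCall(
        index=0,
        function=DeltaFunctionCall(arguments='{"city": "Paris"}'),
    )
    assert cont.id is None and cont.type is None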
39 changes: 25 additions & 14 deletions vllm/entrypoints/openai/serving_chat.py
@@ -16,7 +16,8 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
-                                         ConversationMessage)
+                                         ConversationMessage,
+                                         random_tool_call_id)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -363,9 +364,10 @@ def extract_tool_call_required_streaming(

                 function_name_returned = True
                 delta_message = DeltaMessage(tool_calls=[
-                    DeltaToolCall(function=DeltaFunctionCall(
-                        name=current_tool_call["name"],
-                        arguments=arguments),
+                    DeltaToolCall(id=random_tool_call_id(),
+                                  function=DeltaFunctionCall(
+                                      name=current_tool_call["name"],
+                                      arguments=arguments),
                                   index=len(obj) - 1,
                                   type="function")
                 ])
@@ -382,8 +384,7 @@
                                       # instead of name every time
                                       name=None,
                                       arguments=delta_text),
-                                  index=len(obj) - 1,
-                                  type="function")
+                                  index=len(obj) - 1)
                 ])
         else:
             delta_message = None
@@ -422,7 +423,7 @@ async def chat_completion_stream_generator(
                 and self._should_stream_with_auto_tool_parsing(request))

             all_previous_token_ids: Optional[list[list[int]]]
-            function_name_returned: Optional[list[bool]] = None
+            function_name_returned = [False] * num_choices

             # Only one of these will be used, thus previous_texts and
             # all_previous_token_ids will not be used twice in the same iteration.
@@ -435,7 +436,6 @@
                 reasoning_end_arr = [False] * num_choices
             elif request.tool_choice == "required":
                 previous_texts = [""] * num_choices
-                function_name_returned = [False] * num_choices
                 all_previous_token_ids = None
             else:
                 previous_texts, all_previous_token_ids = None, None
@@ -623,16 +623,27 @@
                             delta_text = previous_text + delta_text
                             current_text = ""

+                        if function_name_returned[i]:
+                            delta_tool_call = DeltaToolCall(
+                                function=DeltaFunctionCall(
+                                    arguments=delta_text),
+                                index=i)
+                        else:
+                            delta_tool_call = DeltaToolCall(
+                                id=random_tool_call_id(),
+                                type="function",
+                                function=DeltaFunctionCall(
+                                    name=tool_choice_function_name,
+                                    arguments=delta_text),
+                                index=i)
+                            function_name_returned[i] = True
+
                         delta_message = DeltaMessage(tool_calls=[
-                            DeltaToolCall(function=DeltaFunctionCall(
-                                name=tool_choice_function_name,
-                                arguments=delta_text),
-                                          index=i)
+                            delta_tool_call,
                         ])

                     elif request.tool_choice == "required":
                         assert previous_texts is not None
+                        assert function_name_returned is not None
                         previous_text = previous_texts[i]
                         current_text = previous_text + delta_text
                         fn_name_returned = function_name_returned[i]
@@ -835,7 +846,7 @@
                     total_tokens=num_prompt_tokens + completion_tokens,
                 )

-                data = chunk.model_dump_json(exclude_unset=True)
+                data = chunk.model_dump_json(exclude_none=True)
                 yield f"data: {data}\n\n"

             # once the final token is handled, if stream_options.include_usage
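
The switch from `exclude_unset=True` to `exclude_none=True` is what keeps the new explicit `None` defaults off the wire: a field explicitly set to `None` counts as "set" and survives `exclude_unset`, but is dropped by `exclude_none`. A small illustration, assuming standard Pydantic v2 semantics:

    from vllm.entrypoints.openai.protocol import DeltaFunctionCall, DeltaToolCall

    delta = DeltaToolCall(id=None, index=0,
                          function=DeltaFunctionCall(arguments='"}'))

    print(delta.model_dump_json(exclude_unset=True))
    # id was explicitly passed, so this still emits "id": null
    print(delta.model_dump_json(exclude_none=True))
    # id is None, so the field is omitted entirely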
(file name not shown in this view)
@@ -9,6 +9,7 @@
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -22,7 +23,6 @@
                                               partial_json_loads)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -200,7 +200,7 @@ def extract_tool_calls_streaming(
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
@@ -7,6 +7,7 @@
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -20,7 +21,6 @@
                                               partial_json_loads)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -182,7 +182,7 @@ def extract_tool_calls_streaming(
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -8,6 +8,7 @@
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -17,7 +18,6 @@
                                               ToolParser, ToolParserManager)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -259,7 +259,7 @@ def extract_tool_calls_streaming(
                 return DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
(file name not shown in this view)
@@ -7,6 +7,7 @@
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -18,7 +19,6 @@
                                               extract_intermediate_diff)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -106,7 +106,7 @@ def extract_tool_calls_streaming(
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
@@ -8,6 +8,7 @@
 import partial_json_parser
 from partial_json_parser.core.options import Allow

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -19,7 +20,6 @@
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizers import MistralTokenizer
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -220,7 +220,7 @@ def extract_tool_calls_streaming(
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -10,6 +10,7 @@
 from partial_json_parser.core.options import Allow
 from transformers import PreTrainedTokenizerBase

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaFunctionCall, DeltaMessage,
                                               DeltaToolCall,
@@ -21,7 +22,6 @@
                                               is_complete_json,
                                               partial_json_loads)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -208,7 +208,7 @@ def extract_tool_calls_streaming(
                 delta = DeltaMessage(tool_calls=[
                     DeltaToolCall(index=self.current_tool_id,
                                   type="function",
-                                  id=f"chatcmpl-tool-{random_uuid()}",
+                                  id=random_tool_call_id(),
                                   function=DeltaFunctionCall(
                                       name=function_name).model_dump(
                                           exclude_none=True))
(file name not shown in this view)
@@ -7,14 +7,14 @@

 from transformers import PreTrainedTokenizerBase

+from vllm.entrypoints.chat_utils import random_tool_call_id
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage,
                                               ExtractedToolCallInformation,
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid

 logger = init_logger(__name__)

@@ -73,7 +73,7 @@ def extract_tool_calls(

         tool_calls: list[ToolCall] = [
             ToolCall(
-                id=f"chatcmpl-tool-{random_uuid()}",
+                id=random_tool_call_id(),
                 type="function",
                 function=FunctionCall(
                     name=raw_function_call["name"],
(file name not shown in this view)
@@ -280,6 +280,7 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
         new_call_args = new_call_args[:-len(withheld_suffix)]
     if not previously_sent_args:
         return DeltaToolCall(id=new_call.id,
+                             type="function",
                              index=index,
                              function=DeltaFunctionCall(
                                  name=new_call.function.name,
@@ -288,5 +289,5 @@

     arg_diff = new_call_args[len(previously_sent_args):]
     return DeltaToolCall(
-        id="", index=index, function=DeltaFunctionCall(
+        id=None, index=index, function=DeltaFunctionCall(
             arguments=arg_diff)) if arg_diff else None
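
Returning `id=None` rather than `id=""` for argument-only deltas matters for the same reason: with `exclude_none=True`, a `None` id disappears from the serialized chunk, whereas an empty string would still be emitted. An illustrative comparison:

    from vllm.entrypoints.openai.protocol import DeltaFunctionCall, DeltaToolCall

    empty_id = DeltaToolCall(id="", index=0,
                             function=DeltaFunctionCall(arguments="}"))
    none_id = DeltaToolCall(id=None, index=0,
                            function=DeltaFunctionCall(arguments="}"))

    print(empty_id.model_dump_json(exclude_none=True))  # still contains "id": ""
    print(none_id.model_dump_json(exclude_none=True))   # "id" omitted entirely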