[Frontend] Added support for HermesToolParser for models without special tokens #16890

Open · wants to merge 11 commits into main

127 changes: 127 additions & 0 deletions tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import pytest

from ....utils import RemoteOpenAIServer

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"

SERVER_ARGS = [
"--enforce-eager",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--enable-lora",
"--lora-modules",
f"{LORA_MODEL}={LORA_MODEL}",
]

TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
},
},
"required": ["location"],
},
},
}]

MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]


@pytest.mark.asyncio
async def test_non_streaming_tool_call():
"""Test tool call in non-streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()

response = await client.chat.completions.create(
model=LORA_MODEL,
messages=MESSAGES,
tools=TOOLS,
tool_choice="auto",
temperature=0.0,
)

assert response.choices
choice = response.choices[0]
message = choice.message

assert choice.finish_reason == "tool_calls"
assert message.tool_calls is not None

tool_call = message.tool_calls[0]
assert tool_call.type == "function"
assert tool_call.function.name == "get_current_weather"

arguments = json.loads(tool_call.function.arguments)
assert "location" in arguments
assert "Boston" in arguments["location"]
print("\n[Non-Streaming Test Passed]")
print(f"Tool Call: {tool_call.function.name}")
print(f"Arguments: {arguments}")


@pytest.mark.asyncio
async def test_streaming_tool_call():
"""Test tool call in streaming mode."""
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
client = server.get_async_client()

stream = await client.chat.completions.create(
model=LORA_MODEL,
messages=MESSAGES,
tools=TOOLS,
tool_choice="auto",
temperature=0.0,
stream=True,
)

tool_call_chunks = {}
async for chunk in stream:
if not chunk.choices:
continue

delta = chunk.choices[0].delta
if not delta or not delta.tool_calls:
continue

for tool_chunk in delta.tool_calls:
index = tool_chunk.index
if index not in tool_call_chunks:
tool_call_chunks[index] = {"name": "", "arguments": ""}

if tool_chunk.function.name:
tool_call_chunks[index]["name"] += tool_chunk.function.name
if tool_chunk.function.arguments:
tool_call_chunks[index][
"arguments"] += tool_chunk.function.arguments

assert len(tool_call_chunks) == 1
reconstructed_tool_call = tool_call_chunks[0]

assert reconstructed_tool_call["name"] == "get_current_weather"

arguments = json.loads(reconstructed_tool_call["arguments"])
assert "location" in arguments
assert "Boston" in arguments["location"]
print("\n[Streaming Test Passed]")
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
print(f"Reconstructed Arguments: {arguments}")
81 changes: 64 additions & 17 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -52,14 +52,51 @@ def __init__(self, tokenizer: AnyTokenizer):
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
-        self.tool_call_start_token_id = self.vocab.get(
-            self.tool_call_start_token)
-        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
-        if (self.tool_call_start_token_id is None
-                or self.tool_call_end_token_id is None):
-            raise RuntimeError(
-                "Hermes 2 Pro Tool parser could not locate tool call start/end "
-                "tokens in the tokenizer!")
+        self.tool_call_start_token_ids = self.model_tokenizer.encode(
+            self.tool_call_start_token, add_special_tokens=False)
+        self.tool_call_end_token_ids = self.model_tokenizer.encode(
+            self.tool_call_end_token, add_special_tokens=False)
+
+        self.tool_call_start_token_array = [
+            self.model_tokenizer.decode([token_id])
+            for token_id in self.tool_call_start_token_ids
+        ]
+
+        self.tool_call_end_token_array = [
+            self.model_tokenizer.decode([token_id])
+            for token_id in self.tool_call_end_token_ids
+        ]
+
+        self.buffered_delta_text = ""
+
+    # The idea: a tokenizer without special tokens may split the tags into
+    # pieces such as "<", "tool", "_call", ">". Buffer those pieces as they
+    # arrive; once the final piece of a tag is seen, flush the buffer and
+    # return the complete tag. If a token that is not part of a tag arrives
+    # while the buffer is non-empty, flush the buffer together with it.
+    def tool_call_delta_buffer(self, delta_text: str):
+        # If the tag sequence is not yet complete, buffer the token and
+        # return "".
+        if (delta_text in self.tool_call_start_token_array
+                or delta_text in self.tool_call_end_token_array):
+            # If delta_text is the final piece of the start or end tag,
+            # empty the buffer and return the buffered text plus delta_text.
+            if (delta_text == self.tool_call_start_token_array[-1]
+                    or delta_text == self.tool_call_end_token_array[-1]):
+                buffered_text = self.buffered_delta_text
+                self.buffered_delta_text = ""
+                return buffered_text + delta_text
+            else:
+                self.buffered_delta_text = self.buffered_delta_text + delta_text
+                return ""
+        else:
+            if self.buffered_delta_text:
+                buffered_text = self.buffered_delta_text
+                self.buffered_delta_text = ""
+                return buffered_text + delta_text
+            else:
+                return delta_text

def extract_tool_calls(
self,
@@ -124,26 +161,36 @@ def extract_tool_calls_streaming(
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
+        # 1. Parsing is based on the text rather than on token ids.
+        # 2. All incoming delta text is passed through tool_call_delta_buffer
+        #    before it is used for parsing.
+
+        delta_text = self.tool_call_delta_buffer(delta_text)
+        # If previous_text ends with the currently buffered text, strip that
+        # suffix so the buffered part is not counted twice.
+        if (len(previous_text) >= len(self.buffered_delta_text)
+                and previous_text[-len(self.buffered_delta_text):]
+                == self.buffered_delta_text):
+            previous_text = previous_text[:-len(self.buffered_delta_text)]
+        current_text = previous_text + delta_text

logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a start tag?
-        if self.tool_call_start_token_id not in current_token_ids:
+        if self.tool_call_start_token not in current_text:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)

try:

# figure out where we are in the parsing by counting tool call
# start & end tags
-            prev_tool_start_count = previous_token_ids.count(
-                self.tool_call_start_token_id)
-            prev_tool_end_count = previous_token_ids.count(
-                self.tool_call_end_token_id)
-            cur_tool_start_count = current_token_ids.count(
-                self.tool_call_start_token_id)
-            cur_tool_end_count = current_token_ids.count(
-                self.tool_call_end_token_id)
+            prev_tool_start_count = previous_text.count(
+                self.tool_call_start_token)
+            prev_tool_end_count = previous_text.count(self.tool_call_end_token)
+            cur_tool_start_count = current_text.count(
+                self.tool_call_start_token)
+            cur_tool_end_count = current_text.count(self.tool_call_end_token)
tool_call_portion = None
text_portion = None

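
To make the buffering in tool_call_delta_buffer concrete, here is a standalone sketch of the same idea, assuming a tokenizer that splits "<tool_call>" into the pieces "<", "tool", "_call", ">" (an assumption for illustration; the end tag is handled identically and is omitted for brevity):

# Pieces a tokenizer without special tokens might split the start tag into.
START_PIECES = ["<", "tool", "_call", ">"]

buffered = ""


def buffer_delta(delta: str) -> str:
    """Return text that is safe to emit; hold back partial tag pieces."""
    global buffered
    if delta in START_PIECES:
        # Final piece of the tag: flush the whole tag at once.
        if delta == START_PIECES[-1]:
            out, buffered = buffered + delta, ""
            return out
        # Still mid-tag: hold the piece back.
        buffered += delta
        return ""
    # Not a tag piece: flush anything held back along with this token.
    if buffered:
        out, buffered = buffered + delta, ""
        return out
    return delta


stream = ["Sure", "!", " ", "<", "tool", "_call", ">", '{"name"']
print([buffer_delta(t) for t in stream])
# -> ['Sure', '!', ' ', '', '', '', '<tool_call>', '{"name"']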