Skip to content

Commit 18cd84c

Browse files
authored
Merge branch 'main' into pr2-mcp-core
2 parents 66cfd6c + 21bb323 commit 18cd84c

File tree

4 files changed

+383
-0
lines changed

4 files changed

+383
-0
lines changed

docs/features/tool_calling.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,19 @@ Supported models:
376376

377377
Flags: `--tool-call-parser olmo3`
378378

379+
### Gigachat 3 Models (`gigachat3`)
380+
381+
Use the chat template from the Hugging Face model files.
382+
383+
Supported models:
384+
385+
* `ai-sage/GigaChat3-702B-A36B-preview`
386+
* `ai-sage/GigaChat3-702B-A36B-preview-bf16`
387+
* `ai-sage/GigaChat3-10B-A1.8B`
388+
* `ai-sage/GigaChat3-10B-A1.8B-bf16`
389+
390+
Flags: `--tool-call-parser gigachat3`
391+
379392
### Models with Pythonic Tool Calls (`pythonic`)
380393

381394
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
6+
import pytest
7+
8+
from tests.entrypoints.openai.tool_parsers.utils import (
9+
run_tool_extraction,
10+
run_tool_extraction_streaming,
11+
)
12+
from vllm.entrypoints.openai.protocol import FunctionCall
13+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
from vllm.tokenizers import TokenizerLike
15+
16+
# Every fixture below calls the same tool; only the arguments differ. The model
# output used in the tests is the literal "function call" trigger immediately
# followed by a JSON object holding "name" and "arguments".
_TOOL_NAME = "manage_user_memory"
_TRIGGER = "function call"


def _call_json(arguments: dict) -> str:
    """Return the JSON payload of a call to ``_TOOL_NAME`` with *arguments*."""
    payload = {"name": _TOOL_NAME, "arguments": arguments}
    return json.dumps(payload, ensure_ascii=False)


def _expected_call(arguments: dict) -> FunctionCall:
    """Return the ``FunctionCall`` the tests expect for *arguments*."""
    serialized = json.dumps(arguments, ensure_ascii=False)
    return FunctionCall(name=_TOOL_NAME, arguments=serialized)


# Flat arguments.
SIMPLE_ARGS_DICT = {"action": "create", "id": "preferences"}
SIMPLE_FUNCTION_JSON = _call_json(SIMPLE_ARGS_DICT)
SIMPLE_FUNCTION_OUTPUT = _TRIGGER + SIMPLE_FUNCTION_JSON
SIMPLE_FUNCTION_CALL = _expected_call(SIMPLE_ARGS_DICT)


# Empty arguments object.
PARAMETERLESS_FUNCTION_JSON = _call_json({})
PARAMETERLESS_FUNCTION_OUTPUT = _TRIGGER + PARAMETERLESS_FUNCTION_JSON
PARAMETERLESS_FUNCTION_CALL = _expected_call({})


# Nested arguments with boolean values.
COMPLEX_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
    "content": {
        "short_answers": True,
        "hate_emojis": True,
        "english_ui": False,
        "russian_math_explanations": True,
    },
}
COMPLEX_FUNCTION_JSON = _call_json(COMPLEX_ARGS_DICT)
COMPLEX_FUNCTION_OUTPUT = _TRIGGER + COMPLEX_FUNCTION_JSON
COMPLEX_FUNCTION_CALL = _expected_call(COMPLEX_ARGS_DICT)
70+
71+
72+
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Text without a tool call is returned unchanged as content."""
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    parser: ToolParser = parser_cls(default_tokenizer)
    plain_reply = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        parser,
        plain_reply,
        streaming=streaming,
    )

    # The whole reply must come back as content, with nothing extracted.
    assert content == plain_reply
    assert len(tool_calls) == 0
83+
84+
85+
# Each fixture is exercised in streaming and non-streaming mode. Parameters are
# (streaming, model_output, expected_tool_calls, expected_content); the expected
# content is None because the outputs contain nothing before the tool call.
TEST_CASES = [
    pytest.param(
        is_streaming,
        raw_output,
        [expected_call],
        None,
        id=f"{label}_{mode}",
    )
    for label, raw_output, expected_call in (
        ("simple", SIMPLE_FUNCTION_OUTPUT, SIMPLE_FUNCTION_CALL),
        ("parameterless", PARAMETERLESS_FUNCTION_OUTPUT, PARAMETERLESS_FUNCTION_CALL),
        ("complex", COMPLEX_FUNCTION_OUTPUT, COMPLEX_FUNCTION_CALL),
    )
    for is_streaming, mode in ((True, "streaming"), (False, "nonstreaming"))
]
129+
130+
131+
@pytest.mark.parametrize(
    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    expected_content: str | None,
    default_tokenizer: TokenizerLike,
):
    """Extracted content and tool calls match the expectation for each case."""
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        parser,
        model_output,
        streaming=streaming,
    )

    assert content == expected_content
    assert len(tool_calls) == len(expected_tool_calls)
    for got, want in zip(tool_calls, expected_tool_calls):
        assert got.type == "function"
        assert got.function.name == want.name
        # Compare decoded JSON so formatting of the serialized arguments
        # (whitespace, key order) does not affect the result.
        got_args = json.loads(got.function.arguments)
        want_args = json.loads(want.arguments)
        assert got_args == want_args
155+
156+
157+
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """A call streamed as the trigger plus two large JSON chunks is rebuilt."""
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    parser: ToolParser = parser_cls(default_tokenizer)

    # Three deltas: the trigger text, then the JSON payload split at offset 40.
    split_at = 40
    deltas = [
        "function call",
        COMPLEX_FUNCTION_JSON[:split_at],
        COMPLEX_FUNCTION_JSON[split_at:],
    ]

    reconstructor = run_tool_extraction_streaming(
        parser,
        deltas,
        assert_one_tool_per_delta=False,
    )

    assert len(reconstructor.tool_calls) == 1
    only_call = reconstructor.tool_calls[0]
    assert only_call.type == "function"
    assert only_call.function.name == "manage_user_memory"
    streamed_args = json.loads(only_call.function.arguments)
    assert streamed_args == COMPLEX_ARGS_DICT

vllm/entrypoints/openai/tool_parsers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@
134134
"xlam_tool_parser",
135135
"xLAMToolParser",
136136
),
137+
"gigachat3": (
138+
"gigachat3_tool_parser",
139+
"GigaChat3ToolParser",
140+
),
137141
}
138142

139143

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
from collections.abc import Sequence
6+
7+
import regex as re
8+
9+
from vllm.entrypoints.chat_utils import make_tool_call_id
10+
from vllm.entrypoints.openai.protocol import (
11+
ChatCompletionRequest,
12+
DeltaFunctionCall,
13+
DeltaMessage,
14+
DeltaToolCall,
15+
ExtractedToolCallInformation,
16+
FunctionCall,
17+
ToolCall,
18+
)
19+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
20+
from vllm.logger import init_logger
21+
from vllm.tokenizers import TokenizerLike
22+
23+
logger = init_logger(__name__)

# Finds the tool-call trigger: the text "function call", an optional
# "<|role_sep|>" + newline, then everything from the first "{" to the end of
# the text (group 1). DOTALL lets the payload span several lines.
REGEX_FUNCTION_CALL = re.compile(
    r"function call(?:<\|role_sep\|>\n)?(\{.*)", flags=re.DOTALL
)

# Captures the value of the first complete `"name": "..."` pair (group 1).
NAME_REGEX = re.compile(r'"name"\s*:\s*"([^"]*)"', flags=re.DOTALL)

# Captures everything that follows `"arguments":` up to the end of the text
# (group 1), including any trailing characters after the arguments value.
ARGS_REGEX = re.compile(r'"arguments"\s*:\s*(.*)', flags=re.DOTALL)
39+
40+
41+
class GigaChat3ToolParser(ToolParser):
    """Tool-call parser for the ``gigachat3`` output format.

    A tool call is the literal text ``function call``, optionally followed by
    ``<|role_sep|>`` and a newline, then a JSON object with ``"name"`` and
    ``"arguments"`` keys. At most one tool call is extracted per response and
    it is always reported at index 0.
    """

    # Trigger spelling with the optional role separator that
    # REGEX_FUNCTION_CALL also accepts. The streaming path must hold this text
    # back just like ``trigger_start``; otherwise it is emitted as content
    # before the regex gets a chance to match.
    _ROLE_SEP_TRIGGER = "function call<|role_sep|>\n{"

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # True once the trigger has been seen in the streamed text.
        self.tool_started: bool = False
        # True once the delta carrying the function name has been emitted.
        self.tool_name_sent: bool = False
        # Id generated for the tool call when its name is emitted.
        self.tool_id: str | None = None
        # Single-entry list holding the name and the arguments text sent so far.
        self.prev_tool_call_arr: list[dict] = []
        # Streamed text withheld while it may still turn into the trigger.
        self.content_buffer: str = ""
        self.trigger_start = "function call{"

    @staticmethod
    def _no_tool_call(model_output: str) -> ExtractedToolCallInformation:
        """Return a result that passes *model_output* through as content."""
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    def _could_be_trigger(self, text: str) -> bool:
        """Tell whether *text* is a prefix of, or starts with, a trigger."""
        triggers = (self.trigger_start, self._ROLE_SEP_TRIGGER)
        return any(
            trigger.startswith(text) or text.startswith(trigger)
            for trigger in triggers
        )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract a tool call from a complete model response.

        Args:
            model_output: Full text generated by the model.
            request: The originating request (not used by this parser).

        Returns:
            A result with one tool call when the text after the trigger is a
            JSON object containing ``"name"`` and ``"arguments"``; the content
            is then the text before the trigger with trailing whitespace
            removed, or ``None`` if that text is blank. In every other case
            the whole output is returned as content with no tool calls.
        """
        match = REGEX_FUNCTION_CALL.search(model_output)
        if not match:
            return self._no_tool_call(model_output)

        json_candidate = match.group(1).strip()
        try:
            data = json.loads(json_candidate)
        except json.JSONDecodeError:
            return self._no_tool_call(model_output)
        if not (isinstance(data, dict) and "name" in data and "arguments" in data):
            return self._no_tool_call(model_output)

        name = data["name"]
        args = data["arguments"]
        # Arguments are returned as a JSON string; a string value is kept as is.
        if not isinstance(args, str):
            args = json.dumps(args, ensure_ascii=False)

        tool_calls = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=name,
                    arguments=args,
                ),
            )
        ]
        prefix = model_output[: match.start()]
        content = prefix.rstrip() if prefix and prefix.strip() else None

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Produce the next streaming delta for the text generated so far.

        Only ``current_text`` and ``delta_text`` are used; the remaining
        parameters are accepted for interface compatibility.

        Before the trigger is seen, text is emitted as content unless it could
        still be the start of a trigger, in which case it is withheld. After
        the trigger is seen, the first delta carries the call id and function
        name, and later deltas carry newly appended arguments text.

        Returns:
            A ``DeltaMessage`` with content or with a single tool-call delta
            at index 0, or ``None`` when there is nothing to emit yet.
        """
        func_name = None
        cur_args = None

        # One search serves both the "has the call started" check and the
        # payload extraction below.
        match = REGEX_FUNCTION_CALL.search(current_text)
        if not self.tool_started:
            if match:
                self.tool_started = True
                self.content_buffer = ""
            else:
                self.content_buffer += delta_text
                clean_buffer = self.content_buffer.lstrip()
                if self._could_be_trigger(clean_buffer):
                    # May still become a tool call; do not leak it as content.
                    return None
                flush_text = self.content_buffer
                self.content_buffer = ""
                return DeltaMessage(content=flush_text)

        if not match:
            return None

        json_tail = match.group(1).strip()
        name_match = NAME_REGEX.search(json_tail)
        if name_match:
            func_name = name_match.group(1)
        args_match = ARGS_REGEX.search(json_tail)
        if args_match:
            cur_args = args_match.group(1).strip()
            if cur_args.endswith("}"):
                # The last '}' may close the outer call object rather than the
                # arguments. Drop it only if what remains is valid JSON.
                candidate = cur_args[:-1].strip()
                try:
                    json.loads(candidate)
                except json.JSONDecodeError:
                    pass
                else:
                    cur_args = candidate

        if not self.prev_tool_call_arr:
            self.prev_tool_call_arr.append({})

        if not self.tool_name_sent:
            if not func_name:
                return None
            self.tool_name_sent = True
            self.tool_id = make_tool_call_id()
            self.prev_tool_call_arr[0]["name"] = func_name
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=0,
                        id=self.tool_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=func_name,
                        ).model_dump(exclude_none=True),
                    )
                ],
                content=None,
            )

        if cur_args is None:
            return None

        # Emit only the part of the arguments text that has not been sent yet.
        prev_args = self.prev_tool_call_arr[0].get("arguments", "")
        if not prev_args:
            delta_args = cur_args
        elif cur_args.startswith(prev_args):
            delta_args = cur_args[len(prev_args) :]
        else:
            # Current text does not extend what was already sent; emit nothing.
            return None
        if not delta_args:
            return None

        self.prev_tool_call_arr[0]["arguments"] = cur_args
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    function=DeltaFunctionCall(
                        arguments=delta_args,
                    ).model_dump(exclude_none=True),
                )
            ],
            content=None,
        )

0 commit comments

Comments
 (0)