Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Feature] Add jamba tool parser #9154

Merged
merged 22 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cbd955a
first working version of jamba tool parsing
tomeras91 Sep 26, 2024
1a8c4e1
lint and format
tomeras91 Sep 26, 2024
0da420d
fix: We don't want to add content if it's an empty string
tomeras91 Oct 1, 2024
310535c
add initial tests for jamba tool parser
tomeras91 Oct 1, 2024
f5c9d09
reduce code duplication with use of parametrize
tomeras91 Oct 1, 2024
6b04e35
fix model outputs to match jamba expected output
tomeras91 Oct 1, 2024
c25cd51
add tests for jamba tool parsing with streaming
tomeras91 Oct 1, 2024
d551be0
Merge branch 'main' into add-jamba-tool-parser
tomeras91 Oct 8, 2024
d31e688
adjust JambaToolParser to changes in upstream
tomeras91 Oct 8, 2024
6a27eb3
Add adjust_request function to JambaToolParser since we need to set s…
tomeras91 Oct 8, 2024
bc16953
update comments and remove unused code
tomeras91 Oct 8, 2024
25d839d
lint & format + adjust tests to new tool parser API
tomeras91 Oct 8, 2024
16542bc
dummy for build
tomeras91 Oct 9, 2024
2a25f10
Revert "dummy for build"
tomeras91 Oct 9, 2024
a935865
Merge branch 'main' into add-jamba-tool-parser
DarkLight1337 Oct 9, 2024
3c757c5
Use #9188 and improve validation
DarkLight1337 Oct 9, 2024
0db1408
removed done TODO
tomeras91 Oct 10, 2024
e5c2878
Merge branch 'add-jamba-tool-parser' of github.com:tomeras91/vllm int…
tomeras91 Oct 10, 2024
20aeb6d
Added Jamba tool calling to docs
tomeras91 Oct 17, 2024
54efc40
Apply #9461
DarkLight1337 Oct 18, 2024
ae9a0b7
Trigger build with fix typo
DarkLight1337 Oct 18, 2024
d5fefe9
Fix missing option
DarkLight1337 Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint & format + adjust tests to new tool parser API
  • Loading branch information
tomeras91 committed Oct 8, 2024
commit 25d839d61eec05da46a7cc3393306c5e866c7880
197 changes: 132 additions & 65 deletions tests/tool_use/test_jamba_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
from typing import Dict, List, Optional, Generator
from typing import Generator, List, Optional

import partial_json_parser
import pytest
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, DeltaMessage
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import get_tokenizer, AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

MODEL = "ai21labs/Jamba-tiny-dev"

Expand All @@ -27,16 +28,20 @@ def assert_tool_calls(actual_tool_calls: List[ToolCall],
expected_tool_calls: List[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)

for actual_tool_call, expected_tool_call in zip(actual_tool_calls, expected_tool_calls):
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16

assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
def stream_delta_message_generator(
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
model_output: str) -> Generator[DeltaMessage, None, None]:
all_token_ids = jamba_tokenizer.encode(model_output,
add_special_tokens=False)

previous_text = ""
previous_tokens = None
Expand All @@ -45,17 +50,18 @@ def stream_delta_message_generator(jamba_tool_parser: JambaToolParser, jamba_tok
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i+1]

new_tokens, delta_text, new_prefix_offset, new_read_offset = detokenize_incrementally(
tokenizer=jamba_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
current_token_ids = all_token_ids[:i + 1]

(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=jamba_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)

current_text = previous_text + delta_text

Expand All @@ -66,20 +72,23 @@ def stream_delta_message_generator(jamba_tool_parser: JambaToolParser, jamba_tok
previous_token_ids,
current_token_ids,
delta_token_ids,
request=None, # type: ignore[arg-type]
)
if delta_message:
yield delta_message

previous_text = current_text
previous_tokens = previous_tokens + new_tokens if previous_tokens else new_tokens
previous_tokens = previous_tokens + new_tokens if previous_tokens\
else new_tokens
prefix_offset = new_prefix_offset
read_offset = new_read_offset


def test_extract_tool_calls_no_tools(jamba_tool_parser):
model_output = "This is a test"
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(model_output)
assert extracted_tool_calls.tools_called == False
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output

Expand All @@ -93,26 +102,55 @@ def test_extract_tool_calls_no_tools(jamba_tool_parser):
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"})))],
None
),
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
None),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"})))],
" Sure! let me call the tool for you."
),
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" Sure! let me call the tool for you."),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"}))),
ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Orlando", "state": "FL", "unit": "fahrenheit"})))],
None
)
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
],
None)
],
)
def test_extract_tool_calls(jamba_tool_parser, model_output, expected_tool_calls, expected_content):
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(model_output)
assert extracted_tool_calls.tools_called == True
def test_extract_tool_calls(jamba_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called

assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

Expand All @@ -128,37 +166,63 @@ def test_extract_tool_calls(jamba_tool_parser, model_output, expected_tool_calls
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
('''This is a test''', [], '''This is a test'''),
(
'''This is a test''',
[],
'''This is a test'''
),
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" "),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"})))],
" "
),
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
})))
],
" Sure! let me call the tool for you."),
(
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"})))],
" Sure! let me call the tool for you."
),
(
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''',
[ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Dallas", "state": "TX", "unit": "fahrenheit"}))),
ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps({"city": "Orlando", "state": "FL", "unit": "fahrenheit"})))],
" "
)
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
[
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit"
}))),
ToolCall(function=FunctionCall(name="get_current_weather",
arguments=json.dumps(
{
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit"
})))
],
" ")
],
)
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, model_output, expected_tool_calls, expected_content):
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output, expected_tool_calls,
expected_content):
other_content: str = ''
function_names: List[str] = []
function_args_strs: List[str] = []
tool_call_idx: int = -1
tool_call_ids: List[Optional[str]] = []

for delta_message in stream_delta_message_generator(jamba_tool_parser, jamba_tokenizer, model_output):
for delta_message in stream_delta_message_generator(
jamba_tool_parser, jamba_tokenizer, model_output):
# role should never be streamed from tool parser
assert not delta_message.role

Expand All @@ -179,9 +243,8 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, model_
tool_call_ids.append(None)

# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id:
if not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id
if tool_call.id and not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id

# if parts of the function start being streamed
if tool_call.function:
Expand All @@ -200,9 +263,13 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, model_

assert other_content == expected_content

actual_tool_calls = [ToolCall(id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(function_args_str, Allow.OBJ | Allow.STR)))
for tool_call_id, function_name, function_args_str in zip(tool_call_ids, function_names, function_args_strs)]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
actual_tool_calls = [
ToolCall(id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR)))
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser

__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser"
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
"JambaToolParser"
]
11 changes: 6 additions & 5 deletions vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import partial_json_parser
from partial_json_parser.core.options import Allow

from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall, ChatCompletionRequest)
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
Expand Down Expand Up @@ -68,8 +69,7 @@ def adjust_request(
return request

def extract_tool_calls(
self,
model_output: str,
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:

# sanity check; avoid unnecessary processing
Expand Down Expand Up @@ -103,7 +103,8 @@ def extract_tool_calls(
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if (len(content)>0 and content != " ") else None)
content=content if
(len(content) > 0 and content != " ") else None)

except Exception as e:
logger.error("Error in extracting tool call from response %s",
Expand Down
Loading