Skip to content

Commit

Permalink
[BugFix] Enforce Mistral ToolCall id constraint when using the Mistra…
Browse files Browse the repository at this point in the history
…l tool call parser (#9020)
  • Loading branch information
gcalmettes authored Oct 3, 2024
1 parent 01843c8 commit 83caf35
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/tool_use/test_parallel_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert tool_call.type == "function"
assert tool_call.function is not None
assert isinstance(tool_call.id, str)
assert len(tool_call.id) > 16
assert len(tool_call.id) >= 9

# make sure the weather tool was called correctly
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
Expand Down Expand Up @@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
if tool_call.id:
tool_call_id_count += 1
assert (isinstance(tool_call.id, str)
and (len(tool_call.id) > 16))
and (len(tool_call.id) >= 9))

# if parts of the function start being streamed
if tool_call.function:
Expand Down
4 changes: 2 additions & 2 deletions tests/tool_use/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert tool_calls[0].type == 'function'
assert tool_calls[0].function is not None
assert isinstance(tool_calls[0].id, str)
assert len(tool_calls[0].id) > 16
assert len(tool_calls[0].id) >= 9

# make sure the weather tool was called (classic example) with arguments
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
Expand Down Expand Up @@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):

assert finish_reason_count == 1
assert role_name == 'assistant'
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)

# validate the name and arguments
assert function_name == WEATHER_TOOL["function"]["name"]
Expand Down
20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field

from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -19,6 +22,19 @@

logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class MistralToolParser(ToolParser):
"""
Expand Down Expand Up @@ -71,8 +87,8 @@ def extract_tool_calls(self,
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [
ToolCall(
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
Expand Down

0 comments on commit 83caf35

Please sign in to comment.